Implementing a Simplified HashMap with TDD

Category: Java · Published: 5 years ago

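In the previous article I walked through the HashMap source code. Striking while the iron is hot, today I set aside some time to implement a simplified HashMap with TDD. In my own tests, it is slightly slower than HashMap under normal conditions.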

Preliminary design

// API sketch only: method bodies are omitted.
public class SimpleHashMap<K, V> {
    public V put(K key, V value);
    public V get(K key);
    public V remove(K key);
    public boolean containsKey(K key);
    public int size();
    public Iterable<V> values();
    public void forEach(Consumer<? super K> action);
}

Tasking

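  • Construct SimpleHashMap with no arguments
  • Constructor that sets the initial capacity
  • Constructor that sets the initial capacity and load factor
  • Initial capacity defaults to 16
  • Load factor defaults to 0.75f
  • The initial resize threshold is initial capacity * load factor
  • On each subsequent resize, threshold = threshold << 1
  • Add the put method
    • Compute the hash
    • Add a hash table to hold the data nodes
    • If the hash table's capacity is 0, or the number of entries reaches the threshold, set a new resize threshold, then grow the table and rehash
    • The hash table index is hash & (capacity - 1)
    • When growing, move the data from the old hash table into the new one
    • Before moving data to the new hash table, rehash each entry: index = entry.hash & (new_capacity - 1)
    • On a hash collision, store the entries in a linked list
    • If more than 8 entries collide in the same bucket, store them in a red-black tree
  • Add the size method
    • Add a size member variable
    • When put adds a new key, size += 1
    • When remove succeeds, size -= 1
    • Account for linked lists
    • Account for red-black trees
  • Add the containsKey method
    • Compute the hash from the key
    • Compute the index from the hash
    • Look up the key at that index; return true if found, otherwise false
    • Handle a null hash table
    • Account for linked lists
    • Account for red-black trees
  • Add the get method
    • Compute the hash from the key
    • Compute the index from the hash
    • Look up the bucket at that index
    • If the bucket holds several data nodes, compare keys both by reference and by equals
    • Return the matching value if the keys are equal, otherwise null
    • Account for linked lists
    • Account for red-black trees
  • Add the remove method
    • Compute the hash from the key
    • Compute the index from the hash
    • Look up the bucket at that index
    • If the keys are equal, remove the matching node (the bucket becomes null when it was the only node) and return its value; otherwise return null
    • Account for linked lists
    • Account for red-black trees
  • Add the values method
    • Each time put succeeds, save the value to a list
    • Each time put replaces a value, replace the corresponding value in the list
    • Each time remove succeeds, delete the value from the list
    • Account for linked lists
    • Account for red-black trees
  • Add the forEach method
    • Iterate over the hash table
    • If a bucket exists, call action.accept(key)
    • Account for linked lists
    • Account for red-black trees
  • Add fail-fast
    • Add a modCount member variable to count modifications
    • Check that modCount is the same before and after iteration
    • If modCount differs, throw ConcurrentModificationException
  • Add an rb tree to hold data nodes once a bucket has more than 8 hash collisions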
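
The hash, index and threshold arithmetic in the list above can be tried in isolation. The following is a minimal sketch, not the article's class: HashIndexSketch and its method names are hypothetical, and it assumes the capacity is always a power of two, which is what makes hash & (capacity - 1) a valid index.

```java
// Hypothetical helper for illustration; the names are not taken from the article's source.
final class HashIndexSketch {
    static final int DEFAULT_INITIAL_CAPACITY = 16;
    static final float DEFAULT_LOAD_FACTOR = 0.75f;

    // Fold the high 16 bits into the low 16 bits; a null key hashes to 0.
    static int hash(Object key) {
        int h;
        return key == null ? 0 : (h = key.hashCode()) ^ (h >>> 16);
    }

    // Map a hash to a table slot; only valid when capacity is a power of two.
    static int index(int hash, int capacity) {
        return hash & (capacity - 1);
    }

    // First resize threshold: initial capacity * load factor.
    static int initialThreshold(int capacity, float loadFactor) {
        return (int) (capacity * loadFactor);
    }

    public static void main(String[] args) {
        int threshold = initialThreshold(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR);
        System.out.println(threshold);           // 12
        System.out.println(threshold << 1);      // 24, the threshold after one resize
        System.out.println(index(hash(17), 16)); // 1
        System.out.println(index(hash(17), 32)); // 17, the same key moves when capacity doubles
    }
}
```

The last two lines show why growing the table requires a rehash: the key 17 sits in slot 1 of a 16-slot table but belongs in slot 17 of a 32-slot table.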

Test coverage

[Image: test coverage report]

Test code

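import static org.assertj.core.api.Assertions.assertThat;

import java.util.ArrayList;
import java.util.List;

import org.assertj.core.util.Lists;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

/**
 * @author lyning
 */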
public class SimpleHashMapTest {

    private SimpleHashMap<Integer, Integer> map;

    @BeforeEach
    public void setUp() throws Exception {
        // given
        this.map = new SimpleHashMap<>();
    }

    /************ size test start **********/
    @Test
    @DisplayName("given empty entries " +
            "when call size() " +
            "then return 0")
    public void size1() {
        // when
        int size = map.size();
        // then
        assertThat(size).isZero();
    }

    @Test
    @DisplayName("given multiple entries(contains duplicate key) " +
            "when call size() " +
            "then return correct size")
    public void size2() {
        // given
        SimpleHashMap<Integer, Integer> map = new SimpleHashMap<>();
        map.put(1, 1);
        map.put(2, 2);
        map.put(3, 3);
        map.put(3, 4);
        map.put(3, 5);
        map.put(4, 4);
        map.put(5, 5);
        map.remove(1);
        map.remove(2);
        // when
        int size = map.size();
        // then
        assertThat(size).isEqualTo(3);
    }

    @Test
    @DisplayName("given multiple entries(hash conflict) " +
            "when call size() " +
            "then return correct size")
    public void size3() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        map.remove(new HashConflict(5));
        map.remove(new HashConflict(3));
        // when
        int size = map.size();
        // then
        assertThat(size).isEqualTo(3);
    }
    /************ size test end **********/


    /************ put test start **********/
    @Test
    @DisplayName("given empty entries " +
            "when put one entry " +
            "then return size 1")
    public void put1() {
        // when
        map.put(1, 1);
        // then
        assertThat(map.size()).isOne();
    }

    @Test
    @DisplayName("given empty entries " +
            "when put two entries(duplicate key) " +
            "then return size 1")
    public void put2() {
        // when
        map.put(1, 1);
        map.put(1, 2);
        // then
        assertThat(map.size()).isEqualTo(1);
    }

    @Test
    @DisplayName("given empty entries " +
            "when put three entries " +
            "then return size 3")
    public void put3() {
        // when
        map.put(1, 1);
        map.put(2, 2);
        map.put(3, 3);
        // then
        assertThat(map.size()).isEqualTo(3);
    }

    @Test
    @DisplayName("should return value " +
            "when call put")
    public void put4() {
        // when
        Integer value = map.put(1, 1);
        // then
        assertThat(value).isEqualTo(1);
    }

    @Test
    @DisplayName("given empty entries " +
            "when put multiple entries(hash conflict) " +
            "then values keep insertion order and duplicate keys hold the last value")
    public void put5() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        // when
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(3), 4);
        map.put(new HashConflict(3), 5);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        // then
        assertThat(Lists.newArrayList(map.values())).isEqualTo(Lists.list(1, 2, 5, 4, 5));
    }

    @Test
    @DisplayName("should auto grow " +
            "when capacity exceed threshold")
    public void put6() {
        // given default threshold = 16 * 0.75 = 12
        // when
        for (int i = 1; i <= 20; i++) {
            map.put(i, i);
        }
        // then
        assertThat(map.size()).isEqualTo(20);
        assertThat(map.get(20)).isEqualTo(20);
    }
    /************ put test end **********/

    /************ get test start **********/
    @Test
    @DisplayName("given empty entries " +
            "when get by null key " +
            "then return null")
    public void get1() {
        // when
        Integer value = map.get(null);
        // then
        assertThat(value).isNull();
    }

    @Test
    @DisplayName("given empty entries " +
            "when get value by not exist key " +
            "then return null")
    public void get2() {
        // when
        Integer value = map.get(2);
        // then
        assertThat(value).isNull();
    }

    @Test
    @DisplayName("given entry " +
            "when get value by not exist key " +
            "then return null")
    public void get3() {
        // given
        map.put(1, 1);
        // when
        Integer value = map.get(2);
        // then
        assertThat(value).isNull();
    }

    @Test
    @DisplayName("given entry " +
            "when get value " +
            "then return value")
    public void get4() {
        // given
        map.put(1, 1);
        // when
        Integer value = map.get(1);
        // then
        assertThat(value).isEqualTo(1);
    }

    @Test
    @DisplayName("given multiple entries(hash conflict) " +
            "when get value by hash conflict key " +
            "then return value")
    public void get5() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(3), 4);
        map.put(new HashConflict(3), 5);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        // when
        Integer value = map.get(new HashConflict(3));
        // then
        assertThat(value).isEqualTo(5);
    }

    @Test
    @DisplayName("given multiple entries(hash conflict) " +
            "when get value by not exist hash conflict key " +
            "then return null")
    public void get6() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        // when
        Integer value = map.get(new HashConflict(6));
        // then
        assertThat(value).isNull();
    }
    /************ get test end **********/


    /************ remove test start **********/
    @Test
    @DisplayName("given empty entries " +
            "when remove by null key " +
            "then return null")
    public void remove1() {
        // when
        Integer value = map.remove(null);
        // then
        assertThat(value).isNull();
    }

    @Test
    @DisplayName("given entry " +
            "when remove by null key " +
            "then return null")
    public void remove2() {
        // given
        map.put(1, 1);
        // when
        Integer value = map.remove(null);
        // then
        assertThat(value).isNull();
    }

    @Test
    @DisplayName("given entry " +
            "when remove by key " +
            "then return value")
    public void remove3() {
        // given
        map.put(1, 1);
        // when
        int value = map.remove(1);
        // then
        assertThat(value).isEqualTo(1);
    }

    @Test
    @DisplayName("given entry " +
            "when remove by not exist key " +
            "then return null")
    public void remove4() {
        // given
        map.put(1, 1);
        // when
        Integer value = map.remove(2);
        // then
        assertThat(value).isNull();
    }

    @Test
    @DisplayName("given multiple entries(hash conflict) " +
            "when remove by hash conflict key " +
            "then return value")
    public void remove5() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        // when
        Integer value = map.remove(new HashConflict(3));
        // then
        assertThat(value).isEqualTo(3);
        assertThat(Lists.newArrayList(map.values())).isEqualTo(Lists.list(1, 2, 4, 5));
    }
    /************ remove test end **********/


    /************ values test start **********/
    @Test
    @DisplayName("given empty entries " +
            "when call values " +
            "then return empty values")
    public void values1() {
        // when
        Iterable<Integer> values = map.values();
        // then
        assertThat(values).isEmpty();
    }

    @Test
    @DisplayName("given multiple entries " +
            "when call values " +
            "then return all values")
    public void values2() {
        // given
        map.put(1, 1);
        map.put(2, 2);
        map.put(3, 3);
        map.put(3, 4);
        map.put(4, 4);
        map.remove(4);
        // when
        Iterable<Integer> values = map.values();
        // then
        assertThat(values.spliterator().estimateSize()).isEqualTo(3);
        assertThat(Lists.newArrayList(values)).isEqualTo(Lists.list(1, 2, 4));
    }
    /************ values test end **********/


    /************ containsKey test start **********/
    @Test
    @DisplayName("given entry " +
            "when key exist " +
            "then return true")
    public void containsKey1() {
        // given
        map.put(1, 1);
        // when
        boolean result = map.containsKey(1);
        // then
        assertThat(result).isTrue();
    }

    @Test
    @DisplayName("given entry " +
            "when key not exist " +
            "then return false")
    public void containsKey2() {
        // given
        map.put(1, 1);
        // when
        boolean result = map.containsKey(2);
        // then
        assertThat(result).isFalse();
    }

    @Test
    @DisplayName("given multiple entries(hash conflict) " +
            "when call containsKey " +
            "then return correct result")
    public void containsKey3() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        // then
        assertThat(map.containsKey(new HashConflict(3))).isTrue();
        assertThat(map.containsKey(new HashConflict(5))).isTrue();
        assertThat(map.containsKey(new HashConflict(6))).isFalse();
    }
    /************ containsKey test end **********/


    /************ forEach test start **********/
    @Test
    @DisplayName("given multiple entries " +
            "when call forEach " +
            "then pass")
    public void forEach1() {
        // given
        map.put(1, 1);
        map.put(2, 2);
        map.put(3, 3);
        map.put(4, 4);
        // when
        List<Integer> results = new ArrayList<>();
        map.forEach((key) -> results.add(map.get(key)));
        // then
        assertThat(results).isEqualTo(Lists.list(1, 2, 3, 4));
    }

    @Test
    @DisplayName("given multiple entries(hash conflict) " +
            "when call forEach " +
            "then pass")
    public void forEach2() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        // when
        List<Integer> results = new ArrayList<>();
        map.forEach((key) -> results.add(map.get(key)));
        // then
        assertThat(results).isEqualTo(Lists.list(1, 2, 3, 4, 5));
    }

    /************ forEach test end **********/

    class HashConflict {
        private int field;

        HashConflict(int field) {
            this.field = field;
        }

        @Override
        public int hashCode() {
            return this.field <= 8 ? 1 : this.field;
        }

        @Override
        public boolean equals(Object obj) {
            // Guard against null and foreign types before casting.
            return obj instanceof HashConflict
                    && ((HashConflict) obj).field == this.field;
        }
    }
}
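
The HashConflict helper in the tests forces several distinct keys into one bucket by returning the same hash code for small field values. As a standalone illustration of that testing technique, here is a hedged sketch that uses java.util.HashMap as a stand-in for SimpleHashMap; CollidingKeyDemo and CollidingKey are hypothetical names introduced only for this example.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical demo class; java.util.HashMap stands in for the article's SimpleHashMap.
final class CollidingKeyDemo {
    // Same idea as the HashConflict test helper: fields up to 8 all share hash code 1.
    static final class CollidingKey {
        private final int field;

        CollidingKey(int field) {
            this.field = field;
        }

        @Override
        public int hashCode() {
            return field <= 8 ? 1 : field;
        }

        @Override
        public boolean equals(Object obj) {
            return obj instanceof CollidingKey && ((CollidingKey) obj).field == field;
        }
    }

    // Put five colliding keys, then overwrite the value stored under key 3.
    static Map<CollidingKey, Integer> sample() {
        Map<CollidingKey, Integer> map = new HashMap<>();
        for (int i = 1; i <= 5; i++) {
            map.put(new CollidingKey(i), i);
        }
        map.put(new CollidingKey(3), 30); // equal key: replaces the value, size is unchanged
        return map;
    }

    public static void main(String[] args) {
        Map<CollidingKey, Integer> map = sample();
        System.out.println(map.size());                           // 5
        System.out.println(map.get(new CollidingKey(3)));         // 30
        System.out.println(map.containsKey(new CollidingKey(6))); // false
    }
}
```

Even though every key lands in the same bucket, equals still tells them apart, which is exactly the behaviour the hash-conflict tests above expect from SimpleHashMap.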

SimpleHashMap source code

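import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;

/**
 * @author lyning
 */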
public class SimpleHashMap<K, V> {
    private static final int DEFAULT_INITIAL_CAPACITY = 16;
    private static final float DEFAULT_LOAD_FACTOR = 0.75f;
    private int size;
    private Bucket<K, V>[] table;
    private int threshold;

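    public boolean containsKey(K key) {
        // The table is allocated lazily, so an untouched map contains nothing.
        if (this.tableEmpty()) {
            return false;
        }
        int hash = this.hash(key);
        int index = this.index(hash);
        Bucket<K, V> bucket = this.table[index];
        return bucket != null
                && bucket.lookup(key) != null;
    }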

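    public void forEach(Consumer<? super K> action) {
        // Nothing to visit before the first put allocates the table.
        if (this.tableEmpty()) {
            return;
        }
        for (Bucket<K, V> bucket : this.table) {
            while (bucket != null) {
                action.accept(bucket.key);
                bucket = bucket.next;
            }
        }
    }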

    public V get(K key) {
        if (this.tableEmpty()) {
            return null;
        }
        int hash = this.hash(key);
        int index = this.index(hash);
        return this.getVal(index, key);
    }

    public V put(K key, V value) {
        if (this.tableEmpty() || this.nearByThreshold()) {
            this.resize();
        }
        int hash = this.hash(key);
        return this.putVal(key, value, hash);
    }

    public V remove(K key) {
        if (this.tableEmpty()) {
            return null;
        }
        int hash = this.hash(key);
        int index = this.index(hash);
        return this.removeVal(index, key);
    }

    public int size() {
        return this.size;
    }

    public Iterable<V> values() {
        if (this.tableEmpty()) {
            return new ArrayList<>();
        }
        List<V> collections = new ArrayList<>();
        this.collectValues(collections);
        return collections;
    }

    private void collectValues(List<V> collections) {
        for (Bucket<K, V> bucket : this.table) {
            while (bucket != null) {
                collections.add(bucket.value);
                bucket = bucket.next;
            }
        }
    }

    private Bucket<K, V> findBucket(int index) {
        return this.table[index];
    }

    private V getVal(int index, K key) {
        Bucket<K, V> bucket = this.findBucket(index);
        if (Objects.isNull(bucket) || Objects.isNull(bucket = bucket.lookup(key))) {
            return null;
        }
        return bucket.value;
    }

    private void grow(int newCap) {
        if (this.tableEmpty()) {
            this.initTable(newCap);
            return;
        }
        this.table = this.rebuildTable(newCap);
    }

    private int hash(K key) {
        int hashcode;
        return key == null
                ? 0
                : (hashcode = key.hashCode()) ^ (hashcode >>> 16);
    }

    private int index(int hash) {
        return hash & (this.table.length - 1);
    }

    private void initTable(int newCap) {
        this.table = new Bucket[newCap];
    }

    private boolean nearByThreshold() {
        return this.size + 1 >= this.threshold;
    }

    private V putVal(K key, V value, int hash) {
        int index = this.index(hash);
        Bucket<K, V> bucket = this.table[index];

        if (Objects.isNull(bucket)) {
            this.table[index] = new Bucket<>(hash, key, value);
        } else {
            Bucket<K, V> indexBucket = bucket.lookup(key);
            if (indexBucket != null) {
                indexBucket.value = value;
                return value;
            }
            bucket.putLast(new Bucket<>(hash, key, value));
        }
        this.size += 1;
        return value;
    }

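    private Bucket<K, V>[] rebuildTable(int newCap) {
        Bucket<K, V>[] oldTable = this.table;
        Bucket<K, V>[] newTable = new Bucket[newCap];
        for (Bucket<K, V> head : oldTable) {
            Bucket<K, V> bucket = head;
            while (bucket != null) {
                // Remember the rest of the old chain, then detach this node.
                Bucket<K, V> next = bucket.next;
                bucket.next = null;
                // Rehash against the new capacity, not the old table's length.
                int index = bucket.hash & (newCap - 1);
                if (newTable[index] == null) {
                    newTable[index] = bucket;
                } else {
                    newTable[index].putLast(bucket);
                }
                bucket = next;
            }
        }
        return newTable;
    }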

    private V removeVal(int index, K key) {
        Bucket<K, V> bucket = this.findBucket(index);
        Bucket<K, V> prev = null;
        while (bucket != null) {
            if (bucket.matchKey(key)) {
                if (Objects.isNull(prev)) {
                    // Removing the head: promote the next node so the rest of the chain survives.
                    this.table[index] = bucket.next;
                } else {
                    prev.next = bucket.next;
                }
                this.size -= 1;
                return bucket.value;
            }
            prev = bucket;
            bucket = bucket.next;
        }
        return null;
    }

    private void resize() {
        int oldCap = this.tableCapacity();
        int newCap = 0;
        if (oldCap == 0) {
            oldCap = DEFAULT_INITIAL_CAPACITY;
            this.threshold = (int) (DEFAULT_INITIAL_CAPACITY * DEFAULT_LOAD_FACTOR);
        } else {
            newCap = oldCap << 1;
            this.threshold = this.threshold << 1;
        }

        if (newCap == 0) {
            newCap = oldCap;
        }
        this.grow(newCap);
    }

    private int tableCapacity() {
        return Objects.isNull(this.table) ? 0 : this.table.length;
    }

    private boolean tableEmpty() {
        return Objects.isNull(this.table);
    }

    static class Bucket<K, V> {
        Bucket<K, V> next;
        int hash;
        K key;
        V value;

        public Bucket(int hash, K key, V value) {
            this.hash = hash;
            this.key = key;
            this.value = value;
        }

        public Bucket<K, V> lookup(K key) {
            Bucket<K, V> bucket = this;
            while (bucket != null) {
                if (bucket.matchKey(key)) {
                    return bucket;
                }
                bucket = bucket.next;
            }
            return null;
        }

        public boolean matchKey(K key) {
            // Objects.equals is null-safe, so a stored null key cannot cause an NPE.
            return Objects.equals(this.key, key);
        }

        public void putLast(Bucket<K, V> bucket) {
            this.last().next = bucket;
        }

        private Bucket<K, V> last() {
            Bucket<K, V> bucket = this;
            while (true) {
                if (Objects.isNull(bucket.next)) {
                    return bucket;
                }
                bucket = bucket.next;
            }
        }
    }
}
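
The Tasking list asks for a rehash with entry.hash & (new_capacity - 1) whenever data moves to a larger table. The sketch below illustrates why every node in a collision chain needs its own new index: nodes that shared a slot in the old table can land in different slots once the capacity doubles. It is a simplified illustration under stated assumptions, not the article's code: RehashSketch and Node are hypothetical, Node is a stripped-down stand-in for Bucket, and the capacity is assumed to be a power of two.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch; Node is a stripped-down stand-in for the article's Bucket.
final class RehashSketch {
    static final class Node {
        final int hash;
        final String key;
        Node next;

        Node(int hash, String key) {
            this.hash = hash;
            this.key = key;
        }
    }

    // Re-index every node against the new capacity (assumed to be a power of two).
    static Node[] rebuild(Node[] oldTable, int newCap) {
        Node[] newTable = new Node[newCap];
        for (Node head : oldTable) {
            Node node = head;
            while (node != null) {
                Node next = node.next; // remember the rest of the old chain
                node.next = null;      // detach before re-linking
                int index = node.hash & (newCap - 1);
                if (newTable[index] == null) {
                    newTable[index] = node;
                } else {
                    Node tail = newTable[index];
                    while (tail.next != null) {
                        tail = tail.next;
                    }
                    tail.next = node;  // append to keep insertion order
                }
                node = next;
            }
        }
        return newTable;
    }

    // Collect the keys stored in one slot, in chain order.
    static List<String> keysAt(Node[] table, int index) {
        List<String> keys = new ArrayList<>();
        for (Node node = table[index]; node != null; node = node.next) {
            keys.add(node.key);
        }
        return keys;
    }

    // Build a 16-slot table whose slot 1 chains hashes 1, 17 and 33, then double it.
    static Node[] demo() {
        Node[] table = new Node[16];
        Node a = new Node(1, "a");
        a.next = new Node(17, "b");
        a.next.next = new Node(33, "c");
        table[1] = a;
        return rebuild(table, 32);
    }

    public static void main(String[] args) {
        Node[] grown = demo();
        System.out.println(keysAt(grown, 1));  // [a, c]
        System.out.println(keysAt(grown, 17)); // [b]
    }
}
```

Moving only the head of each chain would leave "b" in slot 1, where a lookup that computes index 17 could no longer find it.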

Summary

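The hardest part is the red-black tree, which is extremely complex. After an hour I still had not grasped its essentials, so I simply used a linked list in its place; when I have time I will settle down and finish the remaining tasks.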
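
The fail-fast item from the Tasking list is another task that the source above does not yet cover. As a rough idea of what it could look like, here is a minimal sketch: FailFastSketch is a hypothetical class, an ArrayList stands in for the hash table, and the check runs after every callback rather than once after the whole iteration, so that a callback which keeps adding elements cannot loop forever.

```java
import java.util.ArrayList;
import java.util.ConcurrentModificationException;
import java.util.List;
import java.util.function.Consumer;

// Hypothetical sketch of the modCount idea; a plain list stands in for the hash table.
final class FailFastSketch<K> {
    private final List<K> keys = new ArrayList<>();
    private int modCount; // counts structural modifications

    void add(K key) {
        keys.add(key);
        modCount++;
    }

    int size() {
        return keys.size();
    }

    void forEach(Consumer<? super K> action) {
        int expected = modCount; // snapshot taken before iterating
        // Indexed loop, so only our own modCount check can fail here.
        for (int i = 0; i < keys.size(); i++) {
            action.accept(keys.get(i));
            if (modCount != expected) {
                throw new ConcurrentModificationException();
            }
        }
    }

    public static void main(String[] args) {
        FailFastSketch<Integer> sketch = new FailFastSketch<>();
        sketch.add(1);
        sketch.add(2);
        try {
            sketch.forEach(key -> sketch.add(key + 10)); // modifies during iteration
        } catch (ConcurrentModificationException e) {
            System.out.println("fail-fast triggered");
        }
    }
}
```

In SimpleHashMap the same counter would be incremented by put (for new keys) and remove, and checked inside forEach and values.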

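Understand the problem, Tasking, TDD (including refactoring): these are the rules I have been following lately, and I hope they give you something to think about.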

Source code

