Categories:

Tags:



Problem

Design a HashSet without using any built-in hash table libraries.

Implement MyHashSet class:

  • void add(key) Inserts the value key into the HashSet.
  • bool contains(key) Returns whether the value key exists in the HashSet or not.
  • void remove(key) Removes the value key in the HashSet. If key does not exist in the HashSet, do nothing.

Constraints

  • 0 <= key <= 10^6
  • At most 10^4 calls will be made to add, remove, and contains.

Solution

The problem Design HashSet can be solved by implementing a hash function. I chose division hashing for hashing and separate chaining for collision resolution. The divisor was chosen based on the maximum of 10^4 elements that could be inserted into the hash table and an optimal load factor of 75%. More specifically, 13337 was chosen as the divisor since it was the closest prime number to 10000 / 0.75 ≈ 13333 and wasn’t close to a power of 2. Finally, I implemented an AVL tree to use as the data structure for each bucket in separate chaining.

Implementation

// Runs once during static initialisation, before main(): turns off
// synchronisation between the C++ streams and C stdio, and clears the streams
// tied to cin and cout. The stored value is always 0 and is never consulted;
// the variable exists only so that the lambda gets executed.
static const int fast_io = []()
{
    std::ios::sync_with_stdio(false);
    // Fully qualified so the block does not depend on a using-directive
    // elsewhere in the translation unit.
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    return 0;
}();

// A single tree node: one key, the cached height of the subtree rooted at the
// node, and links to the two children.
class Node
{
  public:
    int key = 0;           // value stored in this node
    int height = 1;        // every constructor leaves a fresh node at height 1
    Node *left = nullptr;  // left child, or null
    Node *right = nullptr; // right child, or null

    // Key 0, no children.
    Node() {}

    // Given key, no children.
    Node(int key) : key(key) {}

    // Given key and children; height is still initialised to 1 regardless of
    // the children passed in.
    Node(int key, Node *left, Node *right) : key(key), left(left), right(right) {}
};

class AVLTree
{
  private:
    Node *root;

    int getHeight(Node *node) { return node == NULL ? 0 : node->height; }

    Node *leftRotate(Node *root)
    {
        Node *rnode = root->right;
        root->right = rnode->left;
        rnode->left = root;

        root->height = max(getHeight(root->right), getHeight(root->right)) + 1;
        rnode->height = max(getHeight(rnode->left), getHeight(rnode->left)) + 1;

        return rnode;
    }

    Node *rightRotate(Node *root)
    {
        Node *lnode = root->left;
        root->left = lnode->right;
        lnode->right = root;

        root->height = max(getHeight(root->left), getHeight(root->right)) + 1;
        lnode->height = max(getHeight(lnode->left), getHeight(lnode->right)) + 1;

        return lnode;
    }

    Node *insert(Node *node, int key)
    {
        if (node == NULL)
            return new Node(key);

        if (key < node->key)
            node->left = insert(node->left, key);
        else if (key > node->key)
            node->right = insert(node->right, key);
        else
            return node;

        int lheight = getHeight(node->left);
        int rheight = getHeight(node->right);
        node->height = max(lheight, rheight) + 1;

        if (lheight - rheight > 1 && key < node->left->key)
            return rightRotate(node);
        if (rheight - lheight > 1 && key > node->right->key)
            return leftRotate(node);
        if (lheight - rheight > 1 && key > node->left->key)
        {
            node->left = leftRotate(node->left);
            return rightRotate(node);
        }
        if (rheight - lheight > 1 && key < node->right->key)
        {
            node->right = rightRotate(node->right);
            return leftRotate(node);
        }

        return node;
    }

    Node *erase(Node *node, int key)
    {
        if (node == NULL)
            return node;

        if (key < node->key)
            node->left = erase(node->left, key);
        else if (key > node->key)
            node->right = erase(node->right, key);
        else
        {
            Node *old = node;
            if (node->left == NULL)
                node = node->right;
            else if (node->right == NULL)
                node = node->left;
            else
            {
                Node *next = node->right;
                while (next->left)
                    next = next->left;

                node = new Node(next->key);
                node->left = old->left;
                node->right = erase(old->right, next->key);
            }
            delete old;
        }

        if (node == NULL)
            return node;

        int lheight = getHeight(node->left);
        int rheight = getHeight(node->right);
        node->height = max(lheight, rheight) + 1;

        if (lheight - rheight > 1 && getHeight(node->left->left) >= getHeight(node->left->right))
            return rightRotate(node);
        if (rheight - lheight > 1 && getHeight(node->right->left) <= getHeight(node->right->right))
            return leftRotate(node);
        if (lheight - rheight > 1 && getHeight(node->left->left) < getHeight(node->left->right))
        {
            node->left = leftRotate(node->left);
            return rightRotate(node);
        }
        if (rheight - lheight > 1 && getHeight(node->right->left) > getHeight(node->right->right))
        {
            node->right = rightRotate(node->right);
            return leftRotate(node);
        }

        return node;
    }

    Node *find(Node *node, int key)
    {
        if (node == NULL)
            return node;

        if (key < node->key)
            return find(node->left, key);
        else if (key > node->key)
            return find(node->right, key);

        return node;
    }

  public:
    AVLTree() : root(NULL) {}

    void insert(int key) { root = insert(root, key); }

    void erase(int key) { root = erase(root, key); }

    bool find(int key) { return find(root, key); }
};

class MyHashSet
{
  private:
    vector<AVLTree> bucket;

  public:
    MyHashSet() { bucket.resize(13337); }

    void add(int key) { bucket[key % 13337].insert(key); }

    void remove(int key) { bucket[key % 13337].erase(key); }

    bool contains(int key) { return bucket[key % 13337].find(key); }
};

/**
 * Your MyHashSet object will be instantiated and called as such:
 * MyHashSet* obj = new MyHashSet();
 * obj->add(key);
 * obj->remove(key);
 * bool param_3 = obj->contains(key);
 */