Trie Symbol Table

Posted by Dustin Boston.


The Trie Symbol Table is a data structure designed to efficiently store and retrieve key-value pairs where the keys are strings. It organizes data into a tree-like structure, where each node represents a character, and paths from the root to nodes represent strings.

Source Code Listing

trie_symbol_table.ts

/*!
 * @file Implements Trie Symbol Table data structure in TypeScript.
 * @author Dustin Boston <mail@dustin.boston>
 * @see Sedgewick, R., & Wayne, K. (2011). _Algorithms_ (4th ed.). Addison-Wesley.
 * @see https://algs4.cs.princeton.edu/52trie/TrieST.java.html
 * @see https://learnersbucket.com/tutorials/data-structures/trie-data-structure-in-javascript/
 */

import {UnicodeString} from "./util/unicode_string.ts";

/** A value of type `T`, or `undefined` when no value is present. */
export type Maybe<T> = T | undefined;

/**
 * A single node of the R-way trie.
 *
 * A freshly constructed node has no children and no value.
 */
export class Node<T> {
  /** Links to child nodes; starts out empty. */
  next: Array<Node<T>>;

  /** Value stored at this node, or `undefined` when there is none. */
  value?: T;

  constructor() {
    this.next = [];
    this.value = undefined;
  }
}

/**
 * A symbol table of key-value pairs with string keys and generic values,
 * backed by an R-way trie whose links are indexed by Unicode code point.
 *
 * A symbol table implements the _associative array_ abstraction: when
 * associating a value with a key that is already in the symbol table, the
 * old value is replaced with the new value.
 *
 * Values cannot be `undefined`—setting the value associated with a key to
 * `undefined` is equivalent to deleting the key from the symbol table.
 *
 * Keys are walked one code point at a time, so characters outside the Basic
 * Multilingual Plane (e.g. Emoji) occupy a single trie level. `get`, `put`,
 * `contains`, `delete`, and `longestPrefixOf` visit at most one node per code
 * point of their argument. All traversals are iterative, so long keys and
 * deep tries cannot overflow the call stack.
 */
export class TrieSymbolTable<T> {
  /**
   * Max code point of the Basic Multilingual Plane (UTF-16). Code points
   * above this value (e.g. Emoji) are encoded as a surrogate pair, i.e. two
   * UTF-16 code units, so string indexes must advance by 2 instead of 1.
   */
  readonly #bmpMax = 0xff_ff;

  /** Code point of ".", the wildcard understood by `keysThatMatch`. */
  readonly #wildcard = 0x2e;

  /** The root node in the trie; undefined while the table is empty. */
  #root?: Node<T>;

  /** The number of key-value pairs in this symbol table. */
  #size = 0;

  /**
   * Retrieves the value associated with a given key.
   * @param key - The key for which to retrieve the value.
   * @returns The value associated with the key, or undefined if not found.
   * @throws TypeError if the key is empty or missing.
   */
  get(key: string): Maybe<T> {
    if (!key) throw new TypeError("Undefined key.");
    return this.#find(key)?.value;
  }

  /**
   * Inserts the key-value pair into the symbol table. If the key exists
   * the value will be overwritten. An undefined value will delete the key.
   * @param key - The key to insert or update.
   * @param value - Value to associate with the key, undefined to delete it.
   * @throws TypeError if the key is empty or missing.
   */
  put(key: string, value: Maybe<T>): void {
    if (!key) throw new TypeError("Undefined key.");
    if (value === undefined) {
      this.delete(key);
      return;
    }

    const root = this.#root ?? new Node<T>();
    this.#root = root;

    // Walk the key one code point at a time, creating missing nodes.
    let node = root;
    for (let index = 0; index < key.length; ) {
      const codePoint = this.#codePointAt(key, index);
      let child: Maybe<Node<T>> = node.next[codePoint];
      if (child === undefined) {
        child = new Node<T>();
        node.next[codePoint] = child;
      }

      node = child;
      index += this.#unitsFor(codePoint);
    }

    // Only a previously value-less node represents a new key.
    if (node.value === undefined) this.#size++;
    node.value = value;
  }

  /**
   * Gets number of key-value pairs in this symbol table.
   * @returns The number of key-value pairs in this symbol table.
   */
  get size(): number {
    return this.#size;
  }

  /**
   * Whether the symbol table is empty.
   * @returns true if size is 0, otherwise false.
   */
  get isEmpty(): boolean {
    return this.#size === 0;
  }

  /**
   * Does this symbol table contain the given key?
   * @param key - the key to search for in the symbol table.
   * @returns true if this symbol table contains the key, otherwise false.
   * @throws TypeError if the key is empty or missing.
   */
  contains(key: string): boolean {
    if (!key) throw new TypeError("Undefined key.");
    return this.get(key) !== undefined;
  }

  /**
   * Gets all keys in the symbol table.
   * @returns an array of keys in ascending code point order.
   */
  keys(): string[] {
    return this.keysWithPrefix("");
  }

  /**
   * Returns all of the keys in the set that start with `prefix`.
   * Uses an explicit stack instead of recursion and only visits links that
   * actually exist, rather than probing every possible code point.
   * @param prefix - The prefix to search for.
   * @returns All keys that start with `prefix`, in ascending code point order.
   */
  keysWithPrefix(prefix: string): string[] {
    const results: string[] = [];
    const start = this.#find(prefix);
    if (start === undefined) return results;

    const pending: Array<[Node<T>, string]> = [[start, prefix]];
    for (
      let entry = pending.pop();
      entry !== undefined;
      entry = pending.pop()
    ) {
      const [node, key] = entry;
      if (node.value !== undefined) results.push(key);

      // Push in descending order so the smallest code point is popped first.
      for (const [codePoint, child] of this.#children(node).reverse()) {
        pending.push([child, key + String.fromCodePoint(codePoint)]);
      }
    }

    return results;
  }

  /**
   * Find keys that match a pattern, where "." is a wild card that stands for
   * exactly one code point. Only keys with as many code points as the pattern
   * can match. Uses an explicit stack instead of recursion.
   * @param pattern - the pattern to search for.
   * @returns all keys that match the pattern, in ascending code point order.
   */
  keysThatMatch(pattern: string): string[] {
    const results: string[] = [];
    const root = this.#root;
    if (root === undefined) return results;

    // Each entry: node reached, key spelled so far, next index in `pattern`.
    // The pattern index is tracked separately from the key length because a
    // wildcard may match a code point that occupies two UTF-16 code units.
    const pending: Array<[Node<T>, string, number]> = [[root, "", 0]];
    for (
      let entry = pending.pop();
      entry !== undefined;
      entry = pending.pop()
    ) {
      const [node, key, index] = entry;
      const codePoint = pattern.codePointAt(index);

      if (codePoint === undefined) {
        // Whole pattern consumed: a match only if a value is stored here.
        if (node.value !== undefined) results.push(key);
        continue;
      }

      const nextIndex = index + this.#unitsFor(codePoint);
      if (codePoint === this.#wildcard) {
        // Push in descending order so the smallest code point is popped first.
        for (const [childCode, child] of this.#children(node).reverse()) {
          pending.push([child, key + String.fromCodePoint(childCode), nextIndex]);
        }
      } else {
        const child: Maybe<Node<T>> = node.next[codePoint];
        if (child !== undefined) {
          pending.push([child, key + String.fromCodePoint(codePoint), nextIndex]);
        }
      }
    }

    return results;
  }

  /**
   * Finds the string in the symbol table that is the longest prefix of
   * `query`. This is iterative to prevent a stack overflow.
   * @param query - the string to search for.
   * @returns the longest key that is a prefix of `query`, or the empty
   * string when no key is a prefix of it.
   */
  longestPrefixOf(query: string): string {
    let length = 0;
    let index = 0;
    let node = this.#root;

    while (node !== undefined) {
      // Remember the deepest position at which a complete key ended.
      if (node.value !== undefined) length = index;

      const codePoint = query.codePointAt(index);
      if (codePoint === undefined) break; // Reached the end of the query.

      node = node.next[codePoint];
      index += this.#unitsFor(codePoint);
    }

    return query.slice(0, length);
  }

  /**
   * Removes the key from the set if the key is present, and unlinks any
   * nodes that are left with neither a value nor children.
   * @param key - the key to remove.
   * @throws Error if the key is undefined.
   */
  delete(key: string): void {
    if (key === undefined) throw new Error("A key must be specified.");

    // Record each (parent, code point) link followed so it can be pruned.
    const path: Array<[Node<T>, number]> = [];
    let node = this.#root;
    for (let index = 0; index < key.length; ) {
      if (node === undefined) return; // Key is not in the table.
      const codePoint = this.#codePointAt(key, index);
      path.push([node, codePoint]);
      node = node.next[codePoint];
      index += this.#unitsFor(codePoint);
    }

    if (node === undefined || node.value === undefined) return;
    node.value = undefined;
    this.#size--;

    // Walk back toward the root, removing nodes that became empty.
    let current: Node<T> = node;
    for (const [parent, codePoint] of path.reverse()) {
      if (!this.#isPrunable(current)) return;
      delete parent.next[codePoint];
      current = parent;
    }

    // `current` is now the root; drop it too when nothing is left.
    if (this.#isPrunable(current)) this.#root = undefined;
  }

  /**
   * Follows `key` from the root, one code point at a time.
   * @param key - The key (or prefix) to follow; "" yields the root.
   * @returns The node reached, or undefined if the path does not exist.
   */
  #find(key: string): Maybe<Node<T>> {
    let node = this.#root;
    for (let index = 0; node !== undefined && index < key.length; ) {
      const codePoint = this.#codePointAt(key, index);
      node = node.next[codePoint];
      index += this.#unitsFor(codePoint);
    }

    return node;
  }

  /**
   * Reads the code point that starts at `index` in `text`.
   * @throws RangeError if `index` is outside of `text`.
   */
  #codePointAt(text: string, index: number): number {
    const codePoint = text.codePointAt(index);
    if (codePoint === undefined) throw new RangeError("Invalid index.");
    return codePoint;
  }

  /** Number of UTF-16 code units that encode `codePoint` (1 or 2). */
  #unitsFor(codePoint: number): number {
    return codePoint > this.#bmpMax ? 2 : 1;
  }

  /**
   * Lists the existing children of `node` as `[codePoint, child]` pairs in
   * ascending code point order. `Object.entries` skips the holes of the
   * sparse `next` array, so only links that exist are enumerated.
   */
  #children(node: Node<T>): Array<[number, Node<T>]> {
    return Object.entries(node.next).map(
      ([index, child]): [number, Node<T>] => [Number(index), child],
    );
  }

  /** Whether `node` holds no value and has no children. */
  #isPrunable(node: Node<T>): boolean {
    return node.value === undefined && Object.keys(node.next).length === 0;
  }
}

trie_symbol_table_test.ts

import {TrieSymbolTable} from "../../../../data/algorithms/src/trie_symbol_table.ts";
import {Alphabet} from "../../../../data/algorithms/src/alphabet.ts";
import {assert} from "$assert";

// NOTE(review): TrieSymbolTable as shown in trie_symbol_table.ts declares no
// constructor parameters, so this alphabet looks like a leftover from an
// earlier API — confirm before relying on it.
const alpha = Alphabet.LOWERCASE;

// Inserting the sample sentence stores one entry per distinct word.
test(function testInsert() {
  const t = new TrieSymbolTable<number>();
  const key = "she sells sea shells by the sea shore";
  for (const [i, k] of key.split(" ").entries()) {
    t.put(k, i);
  }

  // `size` is a getter, not a method. "sea" occurs twice, so 8 words
  // produce 7 distinct keys.
  assert(t.size === 7, "Size of trie matches");
});

// keys() lists every distinct key in ascending code point order.
test(function testListKeys() {
  const t = new TrieSymbolTable<number>();
  const key = "she sells sea shells by the sea shore";
  for (const [i, k] of key.split(" ").entries()) {
    t.put(k, i);
  }

  const actual = t.keys();
  assert(actual.length === 7, "Number of keys matches");
  assert(
    actual.join(",") === "by,sea,sells,she,shells,shore,the",
    "Keys are listed in order",
  );
});

// get() returns the value that was stored under a key.
test(function testGetValuesByKey() {
  const wordList = [
    ["by", "b i"],
    ["sea", "s e"],
    ["sells", "s e l z"],
    ["she", "sh e"],
    ["shells", "sh e l z"],
    ["shore", "sh o r"],
    ["the", "<th"],
  ];

  const t = new TrieSymbolTable<string>();
  for (const [key, value_] of wordList) {
    t.put(key, value_);
  }

  const value = t.get("she");
  assert(value === "sh e", "Value stored under the key is returned");
});

// keysWithPrefix() returns only the keys that start with the prefix.
test(function testGetKeysWithPrefix() {
  const t = new TrieSymbolTable<number>();
  const wordList = "she sells sea shells by the sea shore";
  for (const [i, k] of wordList.split(" ").entries()) {
    t.put(k, i);
  }

  const keys = t.keysWithPrefix("sh");
  assert(keys.join(",") === "she,shells,shore", "Keys with prefix match");
});

// keysThatMatch() treats "." as a single-character wildcard.
test(function testMatchKeys() {
  const t = new TrieSymbolTable<number>();
  const wordList = "she sells sea shells by the sea shore";
  for (const [i, k] of wordList.split(" ").entries()) {
    t.put(k, i);
  }

  const keys = t.keysThatMatch("sh.");
  assert(keys.join(",") === "she", "Keys matching the pattern are found");
});

// longestPrefixOf() returns the longest stored key that prefixes the input.
test(function testFindTheLongestPrefixOfAString() {
  const t = new TrieSymbolTable<number>();
  const wordList = "she sells sea shells by the sea shore";
  for (const [i, k] of wordList.split(" ").entries()) {
    t.put(k, i);
  }

  for (const [input, expected] of [
    ["she", "she"],
    ["shell", "she"],
    ["shellsort", "shells"],
    ["shelter", "she"],
    // No stored key is a prefix of this input.
    ["quick", ""],
  ]) {
    const actual = t.longestPrefixOf(input);
    assert(actual === expected, `Longest prefix of "${input}" matches`);
  }
});

// delete() removes exactly the requested key and updates the size.
test(function testDeleteAKey() {
  const t = new TrieSymbolTable<number>();
  const wordList = "she sells sea shells by the sea shore";
  for (const [i, k] of wordList.split(" ").entries()) {
    t.put(k, i);
  }

  t.delete("shells");

  const keys = t.keysWithPrefix("sh");
  assert(keys.join(",") === "she,shore", "Deleted key is no longer listed");
  assert(t.get("shells") === undefined, "Deleted key has no value");
  assert(t.size === 6, "Size is decremented after delete");
});