TextRank

Posted by Dustin Boston.

TextRank is a graph-based ranking model for natural language processing tasks such as text summarization and keyword extraction. It adapts the PageRank algorithm to rank sentences or words by their importance within a text. It uses a graph whose nodes are text units and whose edges represent relationships between them (for example, word co-occurrence or sentence similarity).

Source Code Listing

code.ts

/**
 * Textrank Algorithm Implementation
 * @see Mihalcea, Rada, and Tarau, Paul. (2004). _TextRank: Bringing Order
 *     into Text._ Association for Computational Linguistics.
 * @see https://web.eecs.umich.edu/~mihalcea/papers/mihalcea.emnlp04.pdf
 */

import { Vert } from "./vert";
import { type Weights, type Ngram, type GraphData } from "./types";

/**
 * Strategy that tells `Textrank` what a vertex is and how vertices are
 * connected. Subclasses decide whether the graph ranks keywords or sentences.
 */
export abstract class ExtractionStrategy {
  /**
   * Map an n-gram to the label of the vertex that represents it. N-grams that
   * map to the same label share a single vertex.
   */
  abstract transform(ngram: Ngram): string;

  /**
   * Receive the freshly built graph so the strategy can add its edges.
   * NOTE(review): presumably this fills `adj` and each vertex's
   * `inbound`/`outbound` lists, which start out empty — confirm in the
   * concrete strategies.
   */
  abstract collect(graphData: GraphData): void;
}

/**
 * Default damping factor `d` in the score update `(1 - d) + d * sum`
 * performed by `Textrank.weightedScore`. The TextRank paper, following
 * PageRank, uses d = 0.85.
 */
export const DAMPEN = 0.85;

/**
 * Implementation of the TextRank algorithm. The strategy pattern is used to
 * direct whether we are ranking keywords/co-references or
 * sentences/similarity: the strategy chooses each n-gram's vertex label and
 * is handed the graph to add edges to.
 */
export class Textrank {
  /**
   * @param extractionStrategy Decides vertex labels and graph edges.
   * @param damping Damping factor `d` used by `weightedScore`; defaults to
   *     `DAMPEN` (0.85).
   */
  constructor(
    public extractionStrategy: ExtractionStrategy,
    public damping = DAMPEN,
  ) {}

  /**
   * Build the graph for `ngrams`, let the strategy add edges, score every
   * vertex, and return the vertices ordered by descending score.
   *
   * @param ngrams The text units to rank, in document order.
   * @returns All distinct vertices, highest score first.
   */
  extractData(ngrams: Ngram[]): Vert[] {
    const graphData = this.createGraph(ngrams);
    // createGraph leaves the adjacency rows and vertex edge lists empty; the
    // strategy is given the whole graph so it can populate them.
    this.extractionStrategy.collect(graphData);
    this.scoreVertices(graphData.vertices, graphData.adj);
    return this.sortResult(graphData.vertices);
  }

  /**
   * Create one vertex per distinct label produced by the strategy's
   * `transform`, and record which vertex every input n-gram maps to.
   *
   * @param ngrams The text units to place in the graph.
   * @returns `tokenVec` (vertex id of each input n-gram, in input order),
   *     `vertices` (array index === vertex id) and `adj`, an empty sparse
   *     adjacency matrix with one row per vertex.
   */
  createGraph(ngrams: Ngram[]): GraphData {
    const corpus = new Map<string, number>(); // label -> vertex id
    const vertices: Vert[] = []; // Indices are ids
    const tokenVec: number[] = [];

    for (const ngram of ngrams) {
      const text = this.extractionStrategy.transform(ngram);
      let id = corpus.get(text);
      if (id === undefined) {
        // First time this label is seen: ids are assigned densely from 0.
        id = corpus.size;
        corpus.set(text, id);
        const vert = new Vert(id, text);
        vert.ngram = ngram;
        vertices.push(vert);
      }
      tokenVec.push(id);
    }

    // Sparse adjacency matrix: adj[i][j] stores the number of occurrences of
    // the edge i -> j. Entries that are never set are read as 0 by
    // weightedScore.
    const adj: number[][] = Array.from({ length: vertices.length }, () => []);
    return { tokenVec, vertices, adj };
  }

  /**
   * Iteratively update every vertex's score until each one changes by at
   * most `threshold` between consecutive iterations, or until
   * `maxIterations` iterations have run. Within an iteration, every new score
   * is computed from the previous iteration's scores before any is written.
   *
   * @param vertices Graph vertices; their `score` fields are updated in place.
   * @param adj Edge weights indexed `adj[from][to]`.
   * @param threshold Largest absolute per-vertex change that counts as
   *     converged.
   * @param maxIterations Upper bound on the number of iterations.
   */
  scoreVertices(
    vertices: Vert[],
    adj: number[][],
    threshold = 0.0001,
    maxIterations = 100,
  ): void {
    let converged = 0;
    let iterations = 0;

    while (converged < vertices.length && iterations < maxIterations) {
      iterations++;
      // Score everything against the previous iteration's values first.
      const updates = vertices.map((vert) => ({
        vert,
        score: this.weightedScore(adj, vert),
      }));

      // Compare magnitudes: a score that drops sharply has not converged.
      converged = updates.filter(
        ({ vert, score }) => Math.abs(score - vert.score) <= threshold,
      ).length;

      for (const { vert, score } of updates) vert.score = score;
    }
  }

  /**
   * Return a copy of `vertices` ordered by descending score. The input array
   * is left untouched.
   */
  sortResult(vertices: Vert[]): Vert[] {
    // Copy-then-sort (stable, like toSorted) so this also builds against
    // pre-ES2023 lib targets.
    return [...vertices].sort((a, b) => b.score - a.score);
  }

  /**
   * Compute the new score of `vert` from its inbound neighbours' current
   * scores, using the weighted TextRank update:
   *
   *   WS(Vi) = (1 - d) + d * sum over Vj in In(Vi) of
   *            w_ji / (sum over Vk in Out(Vj) of w_jk) * WS(Vj)
   *
   * The score is returned rather than stored so that all vertices in an
   * iteration are scored against the same previous values.
   *
   * @param weights Edge weights indexed `weights[from][to]`; missing entries
   *     (sparse rows) count as 0.
   * @param vert The vertex to score.
   * @returns The new score for `vert`.
   */
  weightedScore(weights: Weights, vert: Vert): number {
    let sum = 0;
    for (const inbound of vert.inbound) {
      const row = weights[inbound.id] ?? [];
      let denom = 0;
      for (const outbound of inbound.outbound) {
        denom += row[outbound.id] ?? 0;
      }

      // Skip neighbours with no outbound weight, avoiding division by zero.
      if (denom === 0) continue;
      sum += ((row[vert.id] ?? 0) / denom) * inbound.score;
    }

    return 1 - this.damping + this.damping * sum;
  }
}

code_test.ts

types.ts

import { Vert } from "./vert";

/**
 * A single token. Only `text`, `pre`, `post` and `normal` are required.
 * NOTE(review): `pre`/`post` presumably hold surrounding whitespace or
 * punctuation and `normal` a normalized form of `text`; the field set looks
 * like an external tokenizer's term object — confirm against the producer.
 */
export type Term = {
  text: string;
  pre: string;
  post: string;
  normal: string;
  tags?: Set<string>;
  index?: [n?: number, start?: number];
  id?: string;
  chunk?: string;
  dirty?: boolean;
  syllables?: string[];
  root?: string;
};

/**
 * An ordered list of keywords with scores.
 */
export type Scores = Array<[keyword: string, score: number]>;

/** Edge-weight matrix indexed `[from][to]`, as read by `Textrank.weightedScore`. */
export type Weights = Array<Array<number>>;

/**
 * A sequence of terms that an `ExtractionStrategy.transform` maps to a single
 * vertex label.
 */
export type Ngram = Array<Term>;

/** The graph built by `Textrank.createGraph` and handed to the strategy. */
export type GraphData = {
  /** Graph vertices; a vertex's array index equals its id. */
  vertices: Array<Vert>;
  /** Vertex id of each input n-gram, in input order. */
  tokenVec: Array<number>;
  /** Sparse adjacency matrix; `adj[i][j]` is the weight of edge i -> j. */
  adj: Array<Array<number>>;
};

vert.ts

import type {Scores, Term} from "./types";

/**
 * A vertex in the TextRank graph: one distinct text unit plus its edges and
 * its current score.
 */
export class Vert {
  /** Dense id; equals this vertex's index in the graph's vertex array. */
  id: number;
  /** The label produced by the extraction strategy for this unit. */
  val: string;
  /** Vertices with an edge pointing at this one. */
  inbound: Vert[];
  /** Vertices this one points at. */
  outbound: Vert[];
  /** Current TextRank score; starts at 1. */
  score: number;
  /** The n-gram this vertex was first created from. */
  ngram: Term[];
  /** Keyword/score pairs; not read by the ranking code shown here. */
  weights: Scores;

  constructor(
    id: number,
    val: string,
    inbound: Vert[] = [],
    outbound: Vert[] = [],
    score = 1,
    ngram: Term[] = [],
    weights: Scores = [],
  ) {
    this.id = id;
    this.val = val;
    this.inbound = inbound;
    this.outbound = outbound;
    this.score = score;
    this.ngram = ngram;
    this.weights = weights;
  }
}
Tags: