Levenshtein distance


See On Github

Data

Contributor

Generic placeholder thumbnail

by Yonaba

in Lua

Tags

Source Code

-- Levenshtein distance implementation
-- See: http://en.wikipedia.org/wiki/Levenshtein_distance

-- Iterative matrix-based method
-- See: http://en.wikipedia.org/wiki/Levenshtein_distance#Iterative_with_full_matrix

-- Smallest of three numbers. math.min already accepts any number of
-- arguments, so one call covers all three.
local function min(a, b, c)
  return math.min(a, b, c)
end

-- Builds a `row` x `col` table of tables (grid[r][c]) with every cell
-- initialised to 0.
local function matrix(row,col)
  local grid = {}
  for r = 1, row do
    local cells = {}
    for c = 1, col do
      cells[c] = 0
    end
    grid[r] = cells
  end
  return grid
end

-- Calculates the Levenshtein distance between two strings using the full
-- (#strA+1) x (#strB+1) matrix: M[i][j] holds the distance between the
-- first i-1 characters of strA and the first j-1 characters of strB.
-- strA, strB: the strings to compare (both required; a nil argument raises
--   an error from the length operator).
-- Returns the edit distance (number of single-character insertions,
-- deletions and substitutions).
local function lev_iter_based(strA,strB)
  local row, col = #strA + 1, #strB + 1
  local M = matrix(row, col)
  -- First column / first row: distance from a prefix to the empty string.
  for i = 1, row do M[i][1] = i - 1 end
  for j = 1, col do M[1][j] = j - 1 end
  for i = 2, row do
    -- This character does not change across the inner loop; extract it once.
    local charA = strA:sub(i - 1, i - 1)
    local prev, cur = M[i - 1], M[i]
    for j = 2, col do
      local cost = (charA == strB:sub(j - 1, j - 1)) and 0 or 1
      cur[j] = min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + cost)
    end
  end
  return M[row][col]
end

-- Recursive method
-- See: http://en.wikipedia.org/wiki/Levenshtein_distance#Recursive

-- Calculates the Levenshtein distance between two strings by top-down
-- recursion over prefix lengths: the result for (s, t) is the distance
-- between the first s characters of strA and the first t characters of strB.
-- The plain recursion revisits the same (s, t) pairs an exponential number
-- of times, so each result is cached in `memo` and reused.
-- strA, strB: the strings to compare (both required; a nil argument raises
--   an error from the length operator).
-- s, t: prefix lengths; default to #strA and #strB (internal use).
-- memo: cache table shared by the recursive calls (internal use; created
--   on the first call that needs it).
-- Returns the edit distance.
local function lev_recursive_based(strA, strB, s, t, memo)
  s, t = s or #strA, t or #strB
  if s == 0 then return t end
  if t == 0 then return s end
  memo = memo or {}
  -- 1 <= t <= #strB here, so this key is unique per (s, t) pair.
  local key = s * (#strB + 1) + t
  local cached = memo[key]
  if cached ~= nil then return cached end
  local cost = strA:sub(s, s) == strB:sub(t, t) and 0 or 1
  local dist = min(
    lev_recursive_based(strA, strB, s - 1, t, memo) + 1,
    lev_recursive_based(strA, strB, s, t - 1, memo) + 1,
    lev_recursive_based(strA, strB, s - 1, t - 1, memo) + cost
  )
  memo[key] = dist
  return dist
end

-- Public entry point for the recursive variant. Only the two strings are
-- forwarded, so callers cannot reach the internal recursion parameters.
local function lev_recursive(strA, strB)
  return lev_recursive_based(strA, strB)
end

local M = {}
M.lev_iter = lev_iter_based
M.lev_recursive = lev_recursive
return M
-- Tests for levenshtein.lua
local lev_iter      = (require 'levenshtein').lev_iter
local lev_recursive = (require 'levenshtein').lev_recursive

-- Running tallies maintained by run(): cases executed and cases passed.
local total = 0
local pass = 0

-- Fits `str` to `len` characters for column alignment: a shorter string is
-- padded on the right with dots, a longer one is truncated.
local function dec(str, len)
	local missing = len - #str
	if missing > 0 then
		return str .. ('.'):rep(missing)
	end
	return str:sub(1, len)
end

-- Runs one test case and prints a numbered PASSED/FAILED line.
-- message: description shown in the report (fitted to 68 chars by dec).
-- f: zero-argument function; the case passes iff calling it does not raise.
-- Side effects: increments the file-level `total`, and `pass` on success.
-- On failure the raised error is printed on the next line so the failing
-- assertion can be located.
local function run(message, f)
	total = total + 1
	local ok, err = pcall(f)
	if ok then pass = pass + 1 end
	local status = ok and 'PASSED' or 'FAILED'
	print(('%02d. %68s: %s'):format(total, dec(message,68), status))
	if not ok then
		-- tostring: error values are not guaranteed to be strings.
		print(('    %s'):format(tostring(err)))
	end
end

run('Fails on running with no arg', function()
	assert(not pcall(lev_iter))
	assert(not pcall(lev_recursive))
end)

run('Fails if only one string is passed', function()
	assert(not pcall(     lev_iter, 'astring'))
	assert(not pcall(lev_recursive, 'astring'))
end)

-- Each entry: { strA, strB, expected distance }. Both implementations are
-- checked against the same table so their expectations cannot drift apart.
local distance_cases = {
	{ 'Godfather', 'Godfather', 0 },
	{ 'Godfather', 'Godfathe',  1 },
	{ 'Godfather', 'odfather',  1 },
	{ 'Godfather', 'Gdfthr',    3 },
	{ 'seven',     'eight',     5 },
}

run('Otherwise, returns the levenshtein distance', function()
	for _, impl in ipairs({ lev_iter, lev_recursive }) do
		for _, case in ipairs(distance_cases) do
			assert(impl(case[1], case[2]) == case[3])
		end
	end
end)

-- Distance to/from the empty string is the other string's length.
run('Handles empty strings', function()
	for _, impl in ipairs({ lev_iter, lev_recursive }) do
		assert(impl('', '') == 0)
		assert(impl('', 'abc') == 3)
		assert(impl('abc', '') == 3)
	end
end)

-- Final report: separator line, then totals and the pass percentage.
local failed = total - pass
local rate = pass * 100 / total
print(string.rep('-', 80))
print(string.format('Total : %02d: Pass: %02d - Failed : %02d - Success: %.2f %%',
	total, pass, failed, rate))