-- | Greedy Heuristic for solving the MCSP problem.
module MCSP.Heuristics.Greedy (
    greedy,
) where

import Control.Applicative (pure)
import Data.Bool (otherwise)
import Data.Eq (Eq (..))
import Data.Function (on, (.))
import Data.Functor ((<$>))
import Data.Int (Int)
import Data.List (map)
import Data.Maybe (Maybe (..), maybe)
import Data.Ord (Ord (..))
import GHC.Err (errorWithoutStackTrace)
import GHC.Num ((+))
import Text.Show (Show)

import Data.IntMap.Strict (
    IntMap,
    delete,
    empty,
    foldlWithKey',
    insert,
    singleton,
    toAscList,
    union,
 )

import MCSP.Data.Meta (Meta)
import MCSP.Data.Pair (Pair, both, dupe, first, liftP, snd, transpose, uncurry)
import MCSP.Data.String (String (..), length)
import MCSP.Data.String.Extra (Partition, longestCommonSubstring, stripInfix)

-- | An @(index, piece)@ pair, where @piece@ is a substring of some original string and @index@ is
-- the position in that original string at which @piece@ starts.
type IndexedString a = (Int, String a)

-- | The substrings of a single original string, keyed by the position where each one starts.
--
-- Taken together they represent a partition of the original string. Their relative order is
-- carried by the 'IntMap' keys, not by their position in a list.
type IndexedPartition a = IntMap (String a)

-- --------------------------------------- --
-- Longest Common Substring for Partitions --

-- | Outcome of 'lcsPair': a common substring shared by a pair of partitions, together with the
-- indexed strings (one from each side) in which it was found.
data LCSResult a = Result
    { -- | Indexed string from the left partition that contains 'lcs'.
      left :: {-# UNPACK #-} !(IndexedString a),
      -- | The common substring itself; it is a substring of both 'left' and 'right'.
      lcs :: {-# UNPACK #-} !(String a),
      -- | Indexed string from the right partition that contains 'lcs'.
      right :: {-# UNPACK #-} !(IndexedString a)
    }
    deriving stock (Show)

-- | Ranking key for an 'LCSResult'; with the 'Ord' instance built on it, the preferred candidate
-- is the one with the greatest key.
--
-- Keys are compared component by component:
--
--   1. the length of 'lcs', so a longer common substring always wins;
--   2. the negated total length of the 'left' and 'right' strings, so shorter source strings win,
--      hoping to leave fewer partitions behind once 'lcs' is removed;
--   3. the negated sum of both indices, so the smallest positions win.
cmpKey :: LCSResult a -> (Int, Int, Int)
cmpKey Result {..} = (commonLength, -sourceLength, -indexSum)
  where
    (leftIndex, leftString) = left
    (rightIndex, rightString) = right
    commonLength = length lcs
    sourceLength = length leftString + length rightString
    indexSum = leftIndex + rightIndex

-- | Equality through 'cmpKey', kept consistent with the 'Ord' instance: two results with the same
-- key are equal even when their strings differ.
instance Eq (LCSResult a) where
    x == y = cmpKey x == cmpKey y
    x /= y = cmpKey x /= cmpKey y

-- | Orders results by their 'cmpKey', so the preferred candidate is the greatest one.
instance Ord (LCSResult a) where
    compare x y = compare (cmpKey x) (cmpKey y)
    x <= y = cmpKey x <= cmpKey y
    x < y = cmpKey x < cmpKey y
    x > y = cmpKey x > cmpKey y
    x >= y = cmpKey x >= cmpKey y

-- | Finds the best common substring between any string of the left partition and any string of
-- the right partition, as ranked by 'cmpKey' (longest substring first), along with the two
-- indexed strings where it was found.
--
-- Every left/right pair is tried, except that a string strictly shorter than the best substring
-- found so far is skipped, since it cannot contain a longer one. Returns 'Nothing' when no pair
-- of strings has a common substring.
lcsPair :: Ord a => Pair (IndexedPartition a) -> Maybe (LCSResult a)
lcsPair (lhs, rhs) = foldlWithKey' scanLeft Nothing lhs
  where
    -- match the left piece @x@ (at index @i@) against every piece of the right partition
    scanLeft best i x
        | cannotBeat best x = best
        | otherwise = foldlWithKey' (scanRight (i, x)) best rhs
    -- keep the greater of the current best and the LCS of @l@ with @(j, y)@
    scanRight l best j y
        | cannotBeat best y = best
        | otherwise = max best (candidate l (j, y))
    -- the current best LCS is already longer than the whole of @s@
    cannotBeat best s = bestLength best > length s
    bestLength Nothing = 0
    bestLength (Just found) = length (lcs found)
    -- the LCS of two indexed strings, if they share any substring at all
    candidate l r = case longestCommonSubstring (snd l) (snd r) of
        Nothing -> Nothing
        Just common -> Just (Result {left = l, lcs = common, right = r})

-- | Cuts the substring @s@ out of the indexed string @(n, v)@ held in partition @m@.
--
-- The entry at index @n@ is removed and replaced by the pieces of @v@ around the occurrence of @s@
-- located by 'stripInfix'. The prefix keeps index @n@, and the suffix is stored at the index just
-- past @s@. A piece matching 'Null' is not stored.
--
-- Returns @s@ tagged with the index where it starts, so the caller can collect it in another
-- 'IndexedPartition', together with the updated partition.
--
-- Fails with 'errorWithoutStackTrace' when 'stripInfix' cannot find @s@ in @v@.
breakAt ::
    Eq a =>
    String a
    -> IndexedString a
    -> IndexedPartition a
    -> (IndexedString a, IndexedPartition a)
breakAt s (n, v) m = case stripInfix s v of
    Nothing -> errorWithoutStackTrace "greedy: given LCS was not part of the input string."
    Just (prefix, suffix) -> case placePrefix prefix (delete n m) of
        (start, withPrefix) -> ((start, s), placeSuffix suffix start withPrefix)
  where
    -- a non-Null prefix stays at @n@ and shifts the start of @s@ past it
    placePrefix Null p = (n, p)
    placePrefix pre p = (n + length pre, insert n pre p)
    -- a non-Null suffix begins right after @s@, which itself starts at @i@
    placeSuffix Null _ p = p
    placeSuffix suf i p = insert (i + length s) suf p

-- | Finds the best common substring of the two partitions (see 'lcsPair') and cuts it out of the
-- string where it was found on each side (see 'breakAt').
--
-- On success, returns the removed substring indexed for each side, together with both updated
-- partitions. Returns 'Nothing' when the partitions share no common substring.
extractLCS ::
    Ord a =>
    Pair (IndexedPartition a)
    -> Maybe (Pair (IndexedString a), Pair (IndexedPartition a))
extractLCS parts@(lhs, rhs) = case lcsPair parts of
    Nothing -> Nothing
    Just (Result {left = l, lcs = common, right = r}) ->
        let (lPiece, lhs') = breakAt common l lhs
            (rPiece, rhs') = breakAt common r rhs
         in Just ((lPiece, rPiece), (lhs', rhs'))

-- | Core loop of the greedy heuristic, working on indexed partitions.
--
-- While 'extractLCS' finds a common substring in the pending partitions, that substring is cut
-- out of both sides and stored, under its index, in a pair of result partitions that start empty.
-- When no common substring is left, each result partition is merged with the pieces still pending
-- on its side.
indexedGreedy :: Ord a => Pair (IndexedPartition a) -> Pair (IndexedPartition a)
indexedGreedy = loop (empty, empty)
  where
    -- @done@ collects the extracted substrings; @pending@ holds what is still unmatched
    loop done@(doneL, doneR) pending = case extractLCS pending of
        Nothing -> liftP union done pending
        Just (((li, ls), (ri, rs)), pending') ->
            loop (insert li ls doneL, insert ri rs doneR) pending'

-- | MCSP greedy algorithm.
--
-- Each input string starts as a single block at index 0. The algorithm then repeatedly finds the
-- longest common substring (LCS) of the remaining pieces and cuts it out of both strings into the
-- result partitions, until no common substring is left. The pieces that were never matched are
-- added at the end.
--
-- Each partition is returned as its blocks listed by increasing start index. The result is
-- wrapped with 'pure' in 'Meta'; this function performs no other effect.
greedy :: Ord a => Pair (String a) -> Meta (Pair (Partition a))
greedy strings = pure (both blocksInOrder (indexedGreedy (both wholeString strings)))
  where
    -- the whole input string as a single block at index 0
    wholeString = singleton 0
    -- read back the blocks of an indexed partition by increasing index
    blocksInOrder = map snd . toAscList