module Equivalence.EquivalenceClasses
  ( EquivalenceClass
  , Partition
  , discretePartition
  , mergePairs
  , mergeSets
  , getElem
  , getTail
  , tsort
  ) where
import           Data.Foldable (find, foldl')
import           Data.Set      (Set)
import qualified Data.Set      as Set
-- | One block of a partition: the set of elements treated as equivalent.
type EquivalenceClass a = Set a
-- | A collection of equivalence classes. Disjointness of the classes is
-- not enforced by the type; the functions below assume it.
type Partition a = Set (EquivalenceClass a)
-- | The finest partition of the given elements: every distinct element is
-- placed in its own singleton class (duplicate inputs collapse).
discretePartition :: (Ord a) => [a] -> Partition a
discretePartition xs = Set.fromList [Set.singleton x | x <- xs]
-- | Glue pairs of elements together, left to right. For each pair @(x, y)@,
-- the classes containing @x@ and @y@ are removed and replaced by their union.
-- Fails (via 'findEquivalenceClass') if either element is absent.
mergePairs :: (Ord a, Show a) => [(a, a)] -> Partition a -> Partition a
mergePairs pairs start = foldl' glue start pairs
  where
    glue part (x, y) =
      let classX = findEquivalenceClass x part
          classY = findEquivalenceClass y part
          -- Deleting the same class twice is harmless, so x == y needs no special case.
          others = Set.delete classX (Set.delete classY part)
      in Set.insert (classX `Set.union` classY) others
-- | Glue groups of elements together, left to right. For each group, every
-- class touched by one of its elements is replaced by the union of those
-- classes. Fails (via 'findEquivalenceClass') if an element is absent.
mergeSets :: (Ord a, Show a) => [Set a] -> Partition a -> Partition a
mergeSets groups start = foldl' absorb start groups
  where
    absorb current grp =
      let merged    = mergeNEquivalences grp current
          untouched = diffNEquivalences grp current
      in merged `Set.union` untouched
-- | The partition with every class that contains some element of @eq@
-- removed. Fails (via 'findEquivalenceClass') if an element is absent.
diffNEquivalences :: (Ord a, Show a) => Set a -> Partition a -> Partition a
diffNEquivalences eq set = set `Set.difference` touched
  where
    touched = Set.map (`findEquivalenceClass` set) eq
-- | The single merged class formed by uniting every class that contains an
-- element of @eq@, returned as a one-class partition.
--
-- An empty @eq@ touches no classes, so the result is the empty partition.
-- Previously the result was @{∅}@, which made 'mergeSets' insert a spurious
-- empty class. Fails (via 'findEquivalenceClass') if an element is absent.
mergeNEquivalences :: (Ord a, Show a) => Set a -> Partition a -> Partition a
mergeNEquivalences eq set
  | Set.null eq = Set.empty
  | otherwise   = Set.singleton (Set.unions (Set.toList touched))
  where
    touched = Set.map (`findEquivalenceClass` set) eq
-- | The smallest element of a set. Partial: errors on an empty set.
getElem :: Set a -> a
getElem = Set.findMin
-- | A singleton set holding the element chosen by 'getElem'.
getUnitSubset :: Set a -> Set a
getUnitSubset = Set.singleton . getElem
-- | Every element of the set except the one chosen by 'getElem'.
getTail :: (Ord a) => Set a -> Set a
getTail set = set Set.\\ getUnitSubset set
-- | A one-class partition holding the union of the classes of @e1@ and @e2@.
-- Fails (via 'findEquivalenceClass') if either element is absent.
mergeEquivalences :: (Ord a, Show a) => (a, a) -> Partition a -> Partition a
mergeEquivalences (e1,e2) set = Set.singleton merged
  where
    classOf = (`findEquivalenceClass` set)
    merged  = classOf e1 `Set.union` classOf e2
-- | The class containing @element@. If several classes contain it, the
-- smallest in 'Set' order is returned. Calls 'error' (showing the element
-- and the partition) when no class contains it.
findEquivalenceClass :: (Eq a, Show a) => a -> Partition a -> EquivalenceClass a
findEquivalenceClass element set =
  -- 'find' scans the Set in ascending order and stops at the first hit.
  case find (element `elem`) set of
    Just cls -> cls
    Nothing  -> error $ "could not find equivalence class for " ++ show element ++ " in " ++ show set
-- | A binary relation as a set of edges. As used by 'buildOrdering', an
-- edge @(x, y)@ means @x@ must appear before @y@.
type Relation a = Set (a,a)
-- | All edges whose second component (target) equals @item@, i.e. the
-- edges pointing into @item@.
elementInImage :: Ord a => Relation a -> a -> Relation a
elementInImage rel item = Set.filter (\(_, target) -> target == item) rel
-- | The relation with every edge whose first component (source) is @item@
-- removed.
elementNotInDomain :: Ord a => Relation a -> a -> Relation a
elementNotInDomain rel item = Set.filter (\(source, _) -> source /= item) rel
-- | The smallest candidate that no edge of @rel@ points into, if any.
noIncoming :: Ord a => Relation a -> Set a -> Maybe a
noIncoming rel candidates = find hasNoPredecessor candidates
  where
    hasNoPredecessor node = Set.null (elementInImage rel node)
-- | Whether the relation contains a cycle (self-loops included).
--
-- Repeatedly picks an edge-source with no incoming edges and deletes all of
-- its outgoing edges. When no such source remains, any surviving edges must
-- lie on or behind a cycle.
isCyclic :: Ord a => Relation a -> Bool
isCyclic rel =
  case noIncoming rel sources of
    Nothing   -> not (Set.null rel)
    -- 'root' is a source, so at least one edge is removed and this terminates.
    Just root -> isCyclic (elementNotInDomain rel root)
  where
    sources = Set.map fst rel
-- | Topologically sort every element mentioned in @rel@ together with the
-- extra @disconnected@ elements. An edge @(x, y)@ places @x@ before @y@.
-- Returns 'Nothing' when the relation is cyclic.
tsort :: Ord a => Relation a -> Set a -> Maybe [a]
tsort rel disconnected
  | isCyclic rel = Nothing
  | otherwise    = Just (buildOrdering rel allItems)
  where
    allItems = relationElements rel `Set.union` disconnected
-- | Every element appearing on either side of some edge.
relationElements :: Ord a => Relation a -> Set a
relationElements rel = Set.map fst rel `Set.union` Set.map snd rel
-- | Emit items in dependency order. At each step, output the smallest item
-- with no incoming edge, drop its outgoing edges, and recurse.
-- Stops (possibly early) once no remaining item is free of incoming edges.
buildOrdering :: Ord a => Relation a -> Set a -> [a]
buildOrdering relation items =
  case noIncoming relation items of
    Nothing   -> []
    Just next -> next : buildOrdering (elementNotInDomain relation next) (Set.delete next items)