use std::collections::btree_map::Entry;
use std::collections::{BTreeMap, BTreeSet};
pub fn topo_sort_scc<Id, NodesFn, NodeIds, PredsFn, SuccsFn, PredsIter, SuccsIter>(
    mut nodes_fn: NodesFn,
    mut preds_fn: PredsFn,
    succs_fn: SuccsFn,
) -> Vec<Id>
where
    Id: Copy + Eq + Ord,
    NodesFn: FnMut() -> NodeIds,
    NodeIds: IntoIterator<Item = Id>,
    PredsFn: FnMut(Id) -> PredsIter,
    SuccsFn: FnMut(Id) -> SuccsIter,
    PredsIter: IntoIterator<Item = Id>,
    SuccsIter: IntoIterator<Item = Id>,
{
    let scc = scc_kosaraju((nodes_fn)(), &mut preds_fn, succs_fn);
    let topo_sort_order = {
        let mut condensed_preds: BTreeMap<Id, Vec<Id>> = Default::default();
        for v in (nodes_fn)() {
            let v = scc[&v];
            condensed_preds.entry(v).or_default().extend(
                (preds_fn)(v)
                    .into_iter()
                    .map(|u| scc[&u])
                    .filter(|&u| v != u),
            );
        }
        topo_sort((nodes_fn)(), |v| {
            condensed_preds.get(&v).into_iter().flatten().cloned()
        })
        .map_err(drop)
        .expect("No cycles after SCC condensing.")
    };
    topo_sort_order
}
pub fn topo_sort<Id, NodeIds, PredsFn, PredsIter>(
    node_ids: NodeIds,
    mut preds_fn: PredsFn,
) -> Result<Vec<Id>, Vec<Id>>
where
    Id: Copy + Eq + Ord,
    NodeIds: IntoIterator<Item = Id>,
    PredsFn: FnMut(Id) -> PredsIter,
    PredsIter: IntoIterator<Item = Id>,
{
    let (mut marked, mut order) = Default::default();
    fn pred_dfs_postorder<Id, PredsFn, PredsIter>(
        node_id: Id,
        preds_fn: &mut PredsFn,
        marked: &mut BTreeMap<Id, bool>, order: &mut Vec<Id>,
    ) -> Result<(), ()>
    where
        Id: Copy + Eq + Ord,
        PredsFn: FnMut(Id) -> PredsIter,
        PredsIter: IntoIterator<Item = Id>,
    {
        match marked.get(&node_id) {
            Some(_permanent @ true) => Ok(()),
            Some(_temporary @ false) => {
                order.clear();
                order.push(node_id);
                Err(())
            }
            None => {
                marked.insert(node_id, false);
                for next_pred in (preds_fn)(node_id) {
                    pred_dfs_postorder(next_pred, preds_fn, marked, order).map_err(|()| {
                        if order.len() == 1 || order.first().unwrap() != order.last().unwrap() {
                            order.push(node_id);
                        }
                    })?;
                }
                order.push(node_id);
                marked.insert(node_id, true);
                Ok(())
            }
        }
    }
    for node_id in node_ids {
        if pred_dfs_postorder(node_id, &mut preds_fn, &mut marked, &mut order).is_err() {
            let end = order.last().unwrap();
            let beg = order.iter().position(|n| n == end).unwrap();
            order.drain(0..=beg);
            return Err(order);
        }
    }
    Ok(order)
}
pub fn scc_kosaraju<Id, NodeIds, PredsFn, SuccsFn, PredsIter, SuccsIter>(
    nodes: NodeIds,
    mut preds_fn: PredsFn,
    mut succs_fn: SuccsFn,
) -> BTreeMap<Id, Id>
where
    Id: Copy + Eq + Ord,
    NodeIds: IntoIterator<Item = Id>,
    PredsFn: FnMut(Id) -> PredsIter,
    SuccsFn: FnMut(Id) -> SuccsIter,
    PredsIter: IntoIterator<Item = Id>,
    SuccsIter: IntoIterator<Item = Id>,
{
    fn visit<Id, SuccsFn, SuccsIter>(
        succs_fn: &mut SuccsFn,
        u: Id,
        seen: &mut BTreeSet<Id>,
        stack: &mut Vec<Id>,
    ) where
        Id: Copy + Eq + Ord,
        SuccsFn: FnMut(Id) -> SuccsIter,
        SuccsIter: IntoIterator<Item = Id>,
    {
        if seen.insert(u) {
            for v in (succs_fn)(u) {
                visit(succs_fn, v, seen, stack);
            }
            stack.push(u);
        }
    }
    let (mut seen, mut stack) = Default::default();
    for sg in nodes {
        visit(&mut succs_fn, sg, &mut seen, &mut stack);
    }
    let _ = seen;
    fn assign<Id, PredsFn, PredsIter>(
        preds_fn: &mut PredsFn,
        v: Id,
        root: Id,
        components: &mut BTreeMap<Id, Id>,
    ) where
        Id: Copy + Eq + Ord,
        PredsFn: FnMut(Id) -> PredsIter,
        PredsIter: IntoIterator<Item = Id>,
    {
        if let Entry::Vacant(vacant_entry) = components.entry(v) {
            vacant_entry.insert(root);
            for u in (preds_fn)(v) {
                assign(preds_fn, u, root, components);
            }
        }
    }
    let mut components = Default::default();
    for sg in stack.into_iter().rev() {
        assign(&mut preds_fn, sg, sg, &mut components);
    }
    components
}
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use super::*;
    #[test]
    pub fn test_toposort() {
        let edges = [
            (5, 11),
            (11, 2),
            (11, 9),
            (11, 10),
            (7, 11),
            (7, 8),
            (8, 9),
            (3, 8),
            (3, 10),
        ];
        let sort = topo_sort([2, 3, 5, 7, 8, 9, 10, 11], |v| {
            edges
                .iter()
                .filter(move |&&(_, dst)| v == dst)
                .map(|&(src, _)| src)
        });
        assert!(
            sort.is_ok(),
            "Did not expect cycle: {:?}",
            sort.unwrap_err()
        );
        let sort = sort.unwrap();
        println!("{:?}", sort);
        let position: BTreeMap<_, _> = sort.iter().enumerate().map(|(i, &x)| (x, i)).collect();
        for (src, dst) in edges.iter() {
            assert!(position[src] < position[dst]);
        }
    }
    #[test]
    pub fn test_toposort_cycle() {
        let edges = [
            ('A', 'B'),
            ('B', 'C'),
            ('C', 'E'),
            ('D', 'B'),
            ('E', 'F'),
            ('E', 'D'),
        ];
        let ids = edges
            .iter()
            .flat_map(|&(a, b)| [a, b])
            .collect::<BTreeSet<_>>();
        let cycle_rotations = BTreeSet::from_iter([
            ['B', 'C', 'E', 'D'],
            ['C', 'E', 'D', 'B'],
            ['E', 'D', 'B', 'C'],
            ['D', 'B', 'C', 'E'],
        ]);
        let permutations = ids.iter().copied().permutations(ids.len());
        for permutation in permutations {
            let result = topo_sort(permutation.iter().copied(), |v| {
                edges
                    .iter()
                    .filter(move |&&(_, dst)| v == dst)
                    .map(|&(src, _)| src)
            });
            assert!(result.is_err());
            let cycle = result.unwrap_err();
            assert!(
                cycle_rotations.contains(&*cycle),
                "cycle: {:?}, vertex order: {:?}",
                cycle,
                permutation
            );
        }
    }
    #[test]
    pub fn test_scc_kosaraju() {
        let edges = [
            ('a', 'b'),
            ('b', 'c'),
            ('b', 'f'),
            ('b', 'e'),
            ('c', 'd'),
            ('c', 'g'),
            ('d', 'c'),
            ('d', 'h'),
            ('e', 'a'),
            ('e', 'f'),
            ('f', 'g'),
            ('g', 'f'),
            ('h', 'd'),
            ('h', 'g'),
        ];
        let scc = scc_kosaraju(
            'a'..='g',
            |v| {
                edges
                    .iter()
                    .filter(move |&&(_, dst)| v == dst)
                    .map(|&(src, _)| src)
            },
            |u| {
                edges
                    .iter()
                    .filter(move |&&(src, _)| u == src)
                    .map(|&(_, dst)| dst)
            },
        );
        assert_ne!(scc[&'a'], scc[&'c']);
        assert_ne!(scc[&'a'], scc[&'f']);
        assert_ne!(scc[&'c'], scc[&'f']);
        assert_eq!(scc[&'a'], scc[&'b']);
        assert_eq!(scc[&'a'], scc[&'e']);
        assert_eq!(scc[&'c'], scc[&'d']);
        assert_eq!(scc[&'c'], scc[&'h']);
        assert_eq!(scc[&'f'], scc[&'g']);
    }
}