use std::borrow::Cow;
use std::hash::Hash;
use proc_macro2::{Ident, Span, TokenStream};
use quote::ToTokens;
use serde::{Deserialize, Serialize};
use slotmap::new_key_type;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{Expr, ExprPath, GenericArgument, Token, Type};
use self::ops::{OperatorConstraints, Persistence};
use crate::diagnostic::{Diagnostic, Level};
use crate::parse::{HfCode, IndexInt, Operator, PortIndex, Ported};
use crate::pretty_span::PrettySpan;
mod di_mul_graph;
mod eliminate_extra_unions_tees;
mod flat_graph_builder;
mod flat_to_partitioned;
mod graph_write;
mod hydroflow_graph;
mod hydroflow_graph_debugging;
use std::fmt::Display;
pub use di_mul_graph::DiMulGraph;
pub use eliminate_extra_unions_tees::eliminate_extra_unions_tees;
pub use flat_graph_builder::FlatGraphBuilder;
pub use flat_to_partitioned::partition_graph;
pub use hydroflow_graph::{DfirGraph, WriteConfig, WriteGraphType};
pub mod graph_algorithms;
pub mod ops;
new_key_type! {
    /// Slotmap key type for graph nodes (presumably keys [`GraphNode`] values; confirm at usage sites).
    pub struct GraphNodeId;
    /// Slotmap key type for graph edges.
    pub struct GraphEdgeId;
    /// Slotmap key type for subgraphs.
    pub struct GraphSubgraphId;
    /// Slotmap key type for loops.
    pub struct GraphLoopId;
}
// NOTE(review): presumably the identifier name used for the context object in generated
// code; it is not used in this part of the file, so confirm against the users.
const CONTEXT: &str = "context";
// NOTE(review): presumably the identifier name used for the dataflow instance in generated
// code; it is not used in this part of the file, so confirm against the users.
const HYDROFLOW: &str = "df";
// Label returned by `GraphNode::to_pretty_string` and `GraphNode::to_name_string` for
// handoff nodes.
const HANDOFF_NODE_STR: &str = "handoff";
// Label returned by `GraphNode::to_pretty_string` and `GraphNode::to_name_string` for
// module boundary nodes.
const MODULE_BOUNDARY_NODE_STR: &str = "module_boundary";
mod serde_syn {
    use serde::{Deserialize, Deserializer, Serializer};
    pub fn serialize<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
        T: quote::ToTokens,
    {
        serializer.serialize_str(&value.to_token_stream().to_string())
    }
    pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
    where
        D: Deserializer<'de>,
        T: syn::parse::Parse,
    {
        let s = String::deserialize(deserializer)?;
        syn::parse_str(&s).map_err(<D::Error as serde::de::Error>::custom)
    }
}
/// Newtype wrapper around an [`Ident`], (de)serialized as a string via `serde_syn`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialOrd, Ord, PartialEq, Eq)]
struct Varname(#[serde(with = "serde_syn")] pub Ident);
/// A node in the graph: an operator, a handoff, or a module boundary.
///
/// All `Span` fields are skipped by serde and default to `Span::call_site()` when
/// deserialized.
#[derive(Clone, Serialize, Deserialize)]
pub enum GraphNode {
    /// An operator node; the operator is (de)serialized as a string via `serde_syn`.
    Operator(#[serde(with = "serde_syn")] Operator),
    /// A handoff node.
    Handoff {
        /// Span associated with the source side of the handoff (not serialized).
        #[serde(skip, default = "Span::call_site")]
        src_span: Span,
        /// Span associated with the destination side of the handoff (not serialized).
        #[serde(skip, default = "Span::call_site")]
        dst_span: Span,
    },
    /// A module boundary node.
    ModuleBoundary {
        /// NOTE(review): presumably `true` for a module input boundary and `false` for an
        /// output boundary; confirm against the code that constructs this variant.
        input: bool,
        /// Span reported for this node by [`GraphNode::span`] (not serialized).
        #[serde(skip, default = "Span::call_site")]
        import_expr: Span,
    },
}
impl GraphNode {
    /// Returns a human-readable rendering of this node.
    ///
    /// Operators delegate to the operator's own `to_pretty_string`; handoffs and module
    /// boundaries return their fixed labels.
    pub fn to_pretty_string(&self) -> Cow<'static, str> {
        match self {
            Self::Operator(operator) => operator.to_pretty_string().into(),
            Self::Handoff { .. } => Cow::Borrowed(HANDOFF_NODE_STR),
            Self::ModuleBoundary { .. } => Cow::Borrowed(MODULE_BOUNDARY_NODE_STR),
        }
    }

    /// Returns the name of this node.
    ///
    /// Operators delegate to the operator's `name_string`; handoffs and module boundaries
    /// return their fixed labels.
    pub fn to_name_string(&self) -> Cow<'static, str> {
        match self {
            Self::Operator(operator) => operator.name_string().into(),
            Self::Handoff { .. } => Cow::Borrowed(HANDOFF_NODE_STR),
            Self::ModuleBoundary { .. } => Cow::Borrowed(MODULE_BOUNDARY_NODE_STR),
        }
    }

    /// Returns the source span of this node.
    ///
    /// For a handoff this is the join of the source and destination spans, falling back to
    /// the source span alone when the two cannot be joined.
    pub fn span(&self) -> Span {
        match *self {
            Self::Operator(ref operator) => operator.span(),
            Self::Handoff { src_span, dst_span } => match src_span.join(dst_span) {
                Some(joined) => joined,
                None => src_span,
            },
            Self::ModuleBoundary { import_expr, .. } => import_expr,
        }
    }
}
/// Compact debug output: prints the variant name plus, for operators, the pretty-printed span
/// and, for module boundaries, the `input` flag.
impl std::fmt::Debug for GraphNode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Operator(op) => {
                let pretty_span = PrettySpan(op.span());
                write!(f, "Node::Operator({} span)", pretty_span)
            }
            Self::Handoff { .. } => f.write_str("Node::Handoff"),
            Self::ModuleBoundary { input, .. } => {
                write!(f, "Node::ModuleBoundary{{input: {}}}", input)
            }
        }
    }
}
/// Data describing one concrete use of an operator.
#[derive(Clone, Debug)]
pub struct OperatorInstance {
    /// Static reference to the [`OperatorConstraints`] for this operator.
    pub op_constraints: &'static OperatorConstraints,
    /// Port index values for the operator's inputs.
    pub input_ports: Vec<PortIndexValue>,
    /// Port index values for the operator's outputs.
    pub output_ports: Vec<PortIndexValue>,
    /// Identifiers of the singletons referenced.
    /// NOTE(review): presumably those referenced from the operator's arguments — confirm.
    pub singletons_referenced: Vec<Ident>,
    /// Generic arguments of the operator, see [`OpInstGenerics`].
    pub generics: OpInstGenerics,
    /// Operator arguments as a comma-punctuated list of parsed expressions.
    /// NOTE(review): the `_pre` suffix presumably means "before some later rewriting" —
    /// confirm against the code that populates this field.
    pub arguments_pre: Punctuated<Expr, Token![,]>,
    /// Operator arguments as a raw token stream.
    pub arguments_raw: TokenStream,
}
/// Generic arguments of an operator, as produced by [`get_operator_generics`].
#[derive(Clone, Debug)]
pub struct OpInstGenerics {
    /// The operator's generic arguments as written, if any.
    pub generic_args: Option<Punctuated<GenericArgument, Token![,]>>,
    /// Persistences parsed from the leading lifetime arguments (`'static`, `'tick`, `'mutable`).
    pub persistence_args: Vec<Persistence>,
    /// The type arguments that follow the persistence lifetimes.
    pub type_args: Vec<Type>,
}
pub fn get_operator_generics(
    diagnostics: &mut Vec<Diagnostic>,
    operator: &Operator,
) -> OpInstGenerics {
    let generic_args = operator.type_arguments().cloned();
    let persistence_args = generic_args.iter().flatten().map_while(|generic_arg| match generic_arg {
            GenericArgument::Lifetime(lifetime) => {
                match &*lifetime.ident.to_string() {
                    "static" => Some(Persistence::Static),
                    "tick" => Some(Persistence::Tick),
                    "mutable" => Some(Persistence::Mutable),
                    _ => {
                        diagnostics.push(Diagnostic::spanned(
                            generic_arg.span(),
                            Level::Error,
                            format!("Unknown lifetime generic argument `'{}`, expected `'tick`, `'static`, or `'mutable`.", lifetime.ident),
                        ));
                        None
                    }
                }
            },
            _ => None,
        }).collect::<Vec<_>>();
    let type_args = generic_args
        .iter()
        .flatten()
        .skip(persistence_args.len())
        .map_while(|generic_arg| match generic_arg {
            GenericArgument::Type(typ) => Some(typ),
            _ => None,
        })
        .cloned()
        .collect::<Vec<_>>();
    OpInstGenerics {
        generic_args,
        persistence_args,
        type_args,
    }
}
/// Classification ("color") of a node.
///
/// NOTE(review): the variant meanings below are inferred from their names only — confirm
/// against the code that assigns colors.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Color {
    /// Pull.
    Pull,
    /// Push.
    Push,
    /// Presumably "computation".
    Comp,
    /// Presumably "handoff".
    Hoff,
}
/// A port index that is either given explicitly (integer or path) or elided.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PortIndexValue {
    /// An integer port index, (de)serialized as a string via `serde_syn`.
    Int(#[serde(with = "serde_syn")] IndexInt),
    /// A path port index, (de)serialized as a string via `serde_syn`.
    Path(#[serde(with = "serde_syn")] ExprPath),
    /// No port index was given. The optional span is not serialized; `from_ported` fills it
    /// with the span of the ported item.
    Elided(#[serde(skip)] Option<Span>),
}
impl PortIndexValue {
    /// Splits a [`Ported`] item into `(input port, inner item, output port)`.
    ///
    /// A missing port becomes [`Self::Elided`] carrying the span of the inner item.
    pub fn from_ported<Inner>(ported: Ported<Inner>) -> (Self, Inner, Self)
    where
        Inner: Spanned,
    {
        let elided_span = Some(ported.inner.span());
        let port_inn: Self = match ported.inn {
            Some(indexing) => indexing.index.into(),
            None => Self::Elided(elided_span),
        };
        let inner = ported.inner;
        let port_out: Self = match ported.out {
            Some(indexing) => indexing.index.into(),
            None => Self::Elided(elided_span),
        };
        (port_inn, inner, port_out)
    }

    /// Returns `true` unless this port index is [`Self::Elided`].
    pub fn is_specified(&self) -> bool {
        match self {
            Self::Elided(_) => false,
            Self::Int(_) | Self::Path(_) => true,
        }
    }

    /// Combines two port indices, keeping whichever one is specified.
    ///
    /// Returns `Err(self)` when both are specified. When `self` is elided, `other` is
    /// returned as-is (so two elided values yield `other`).
    pub fn combine(self, other: Self) -> Result<Self, Self> {
        match (self.is_specified(), other.is_specified()) {
            (true, true) => Err(self),
            (true, false) => Ok(self),
            (false, _) => Ok(other),
        }
    }

    /// Formats the port index for use in error messages: backtick-quoted when specified,
    /// `<elided>` otherwise.
    pub fn as_error_message_string(&self) -> String {
        match self {
            Self::Int(index) => format!("`{}`", index.value),
            Self::Path(path) => format!("`{}`", path.to_token_stream()),
            Self::Elided(_) => String::from("<elided>"),
        }
    }

    /// Returns the span of the port index; an elided port without a recorded span falls back
    /// to `Span::call_site()`.
    pub fn span(&self) -> Span {
        match self {
            Self::Int(index) => index.span(),
            Self::Path(path) => path.span(),
            Self::Elided(maybe_span) => match *maybe_span {
                Some(recorded) => recorded,
                None => Span::call_site(),
            },
        }
    }
}
impl From<PortIndex> for PortIndexValue {
    fn from(value: PortIndex) -> Self {
        match value {
            PortIndex::Int(x) => Self::Int(x),
            PortIndex::Path(x) => Self::Path(x),
        }
    }
}
/// Equality that ignores the span stored in `Elided`: any two elided values are equal.
impl PartialEq for PortIndexValue {
    fn eq(&self, other: &Self) -> bool {
        match self {
            Self::Int(lhs) => matches!(other, Self::Int(rhs) if lhs == rhs),
            Self::Path(lhs) => matches!(other, Self::Path(rhs) if lhs == rhs),
            Self::Elided(_) => matches!(other, Self::Elided(_)),
        }
    }
}
// Marker impl: declares the hand-written `PartialEq` for `PortIndexValue` to be a full
// equivalence relation.
impl Eq for PortIndexValue {}
/// Partial ordering that always defers to the total order defined by `Ord`.
impl PartialOrd for PortIndexValue {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for PortIndexValue {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match (self, other) {
            (Self::Int(s), Self::Int(o)) => s.cmp(o),
            (Self::Path(s), Self::Path(o)) => s
                .to_token_stream()
                .to_string()
                .cmp(&o.to_token_stream().to_string()),
            (Self::Elided(_), Self::Elided(_)) => std::cmp::Ordering::Equal,
            (Self::Int(_), Self::Path(_)) => std::cmp::Ordering::Less,
            (Self::Path(_), Self::Int(_)) => std::cmp::Ordering::Greater,
            (_, Self::Elided(_)) => std::cmp::Ordering::Less,
            (Self::Elided(_), _) => std::cmp::Ordering::Greater,
        }
    }
}
impl Display for PortIndexValue {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PortIndexValue::Int(x) => write!(f, "{}", x.to_token_stream()),
            PortIndexValue::Path(x) => write!(f, "{}", x.to_token_stream()),
            PortIndexValue::Elided(_) => write!(f, "[]"),
        }
    }
}
/// Runs the whole pipeline from parsed code to a partitioned graph plus generated code.
///
/// Steps: build the flat graph, merge modules, eliminate extra unions/tees, partition the
/// graph, then emit code. The first element of the result is `None` whenever any step
/// reports an error; the diagnostics collected so far are always returned.
pub fn build_hfcode(
    hf_code: HfCode,
    root: &TokenStream,
) -> (Option<(DfirGraph, TokenStream)>, Vec<Diagnostic>) {
    let (mut flat_graph, uses, mut diagnostics) = FlatGraphBuilder::from_hfcode(hf_code).build();
    if diagnostics.iter().any(Diagnostic::is_error) {
        return (None, diagnostics);
    }

    if let Err(merge_diagnostic) = flat_graph.merge_modules() {
        diagnostics.push(merge_diagnostic);
        return (None, diagnostics);
    }

    eliminate_extra_unions_tees(&mut flat_graph);

    let partitioned_graph = match partition_graph(flat_graph) {
        Ok(graph) => graph,
        Err(partition_diagnostic) => {
            diagnostics.push(partition_diagnostic);
            return (None, diagnostics);
        }
    };

    let code = partitioned_graph.as_code(
        root,
        true,
        quote::quote! { #( #uses )* },
        &mut diagnostics,
    );

    // Code generation may itself have reported errors; only succeed if it did not.
    if diagnostics.iter().any(Diagnostic::is_error) {
        (None, diagnostics)
    } else {
        (Some((partitioned_graph, code)), diagnostics)
    }
}
/// Recursively re-spans every token in `tokens` to `span`.
///
/// Groups (including their delimiters), punctuation and literals are set to `span` directly.
/// Identifiers get `span.resolved_at(ident.span())`, i.e. the location of `span` combined
/// with the name-resolution context of the identifier's original span.
fn change_spans(tokens: TokenStream, span: Span) -> TokenStream {
    use proc_macro2::{Group, TokenTree};
    tokens
        .into_iter()
        .map(|token| match token {
            TokenTree::Group(group) => {
                // `Group::new` always creates the group with `Span::call_site()`, so the span
                // must be applied to the rebuilt group; setting it on the old group (which is
                // discarded) would have no effect.
                let mut respanned =
                    Group::new(group.delimiter(), change_spans(group.stream(), span));
                respanned.set_span(span);
                TokenTree::Group(respanned)
            }
            TokenTree::Ident(mut ident) => {
                ident.set_span(span.resolved_at(ident.span()));
                TokenTree::Ident(ident)
            }
            TokenTree::Punct(mut punct) => {
                punct.set_span(span);
                TokenTree::Punct(punct)
            }
            TokenTree::Literal(mut literal) => {
                literal.set_span(span);
                TokenTree::Literal(literal)
            }
        })
        .collect()
}