Skip to main content

dfir_lang/graph/
meta_graph.rs

1#![warn(missing_docs)]
2
3extern crate proc_macro;
4
5use std::collections::{BTreeMap, BTreeSet};
6use std::fmt::Debug;
7use std::iter::FusedIterator;
8
9use itertools::Itertools;
10use proc_macro2::{Ident, Literal, Span, TokenStream};
11use quote::{ToTokens, format_ident, quote, quote_spanned};
12use serde::{Deserialize, Serialize};
13use slotmap::{Key, SecondaryMap, SlotMap, SparseSecondaryMap};
14use syn::spanned::Spanned;
15
16use super::graph_write::{Dot, GraphWrite, Mermaid};
17use super::ops::{
18    DelayType, OPERATORS, OperatorWriteOutput, WriteContextArgs, find_op_op_constraints,
19    null_write_iterator_fn,
20};
21use super::{
22    CONTEXT, Color, DiMulGraph, GRAPH, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId,
23    GraphSubgraphId, HANDOFF_NODE_STR, MODULE_BOUNDARY_NODE_STR, OperatorInstance, PortIndexValue,
24    Varname, change_spans, get_operator_generics,
25};
26use crate::diagnostic::{Diagnostic, Diagnostics, Level};
27use crate::pretty_span::{PrettyRowCol, PrettySpan};
28use crate::process_singletons;
29
30/// An abstract "meta graph" representation of a DFIR graph.
31///
32/// Can be with or without subgraph partitioning, stratification, and handoff insertion. This is
33/// the meta graph used for generating Rust source code in macros from DFIR sytnax.
34///
35/// This struct has a lot of methods for manipulating the graph, vaguely grouped together in
36/// separate `impl` blocks. You might notice a few particularly specific arbitray-seeming methods
37/// in here--those are just what was needed for the compilation algorithms. If you need another
38/// method then add it.
39#[derive(Default, Debug, Serialize, Deserialize)]
40pub struct DfirGraph {
41    /// Each node type (operator or handoff).
42    nodes: SlotMap<GraphNodeId, GraphNode>,
43
44    /// Instance data corresponding to each operator node.
45    /// This field will be empty after deserialization.
46    #[serde(skip)]
47    operator_instances: SecondaryMap<GraphNodeId, OperatorInstance>,
48    /// Debugging/tracing tag for each operator node.
49    operator_tag: SecondaryMap<GraphNodeId, String>,
50    /// Graph data structure (two-way adjacency list).
51    graph: DiMulGraph<GraphNodeId, GraphEdgeId>,
52    /// Input and output port for each edge.
53    ports: SecondaryMap<GraphEdgeId, (PortIndexValue, PortIndexValue)>,
54
55    /// Which loop a node belongs to (or none for top-level).
56    node_loops: SecondaryMap<GraphNodeId, GraphLoopId>,
57    /// Which nodes belong to each loop.
58    loop_nodes: SlotMap<GraphLoopId, Vec<GraphNodeId>>,
59    /// For the loop, what is its parent (`None` for top-level).
60    loop_parent: SparseSecondaryMap<GraphLoopId, GraphLoopId>,
61    /// What loops are at the root.
62    root_loops: Vec<GraphLoopId>,
63    /// For the loop, what are its child loops.
64    loop_children: SecondaryMap<GraphLoopId, Vec<GraphLoopId>>,
65
66    /// Which subgraph each node belongs to.
67    node_subgraph: SecondaryMap<GraphNodeId, GraphSubgraphId>,
68
69    /// Which nodes belong to each subgraph.
70    subgraph_nodes: SlotMap<GraphSubgraphId, Vec<GraphNodeId>>,
71
72    /// Resolved singletons varnames references, per node.
73    node_singleton_references: SparseSecondaryMap<GraphNodeId, Vec<Option<GraphNodeId>>>,
74    /// What variable name each graph node belongs to (if any). For debugging (graph writing) purposes only.
75    node_varnames: SparseSecondaryMap<GraphNodeId, Varname>,
76
77    /// Delay type for handoff nodes that represent tick-boundary back-edges.
78    /// Set by `order_subgraphs` for `defer_tick` / `defer_tick_lazy`, either on handoff nodes
79    /// it injects or on existing handoff nodes that it marks as tick-boundary back-edges.
80    handoff_delay_type: SparseSecondaryMap<GraphNodeId, DelayType>,
81
82    /// Whether each node produces exactly one item (is a singleton).
83    /// Computed by propagation: a node is a singleton if it has `has_singleton_output`,
84    /// or if all its inputs are singletons and it has `preserves_singleton`.
85    node_is_singleton: SparseSecondaryMap<GraphNodeId, ()>,
86}
87
88/// Basic methods.
89impl DfirGraph {
90    /// Create a new empty graph.
91    pub fn new() -> Self {
92        Default::default()
93    }
94}
95
96/// Node methods.
97impl DfirGraph {
98    /// Get a node with its operator instance (if applicable).
99    pub fn node(&self, node_id: GraphNodeId) -> &GraphNode {
100        self.nodes.get(node_id).expect("Node not found.")
101    }
102
103    /// Get the `OperatorInstance` for a given node. Node must be an operator and have an
104    /// `OperatorInstance` present, otherwise will return `None`.
105    ///
106    /// Note that no operator instances will be persent after deserialization.
107    pub fn node_op_inst(&self, node_id: GraphNodeId) -> Option<&OperatorInstance> {
108        self.operator_instances.get(node_id)
109    }
110
111    /// Get the debug variable name attached to a graph node.
112    pub fn node_varname(&self, node_id: GraphNodeId) -> Option<&Varname> {
113        self.node_varnames.get(node_id)
114    }
115
116    /// Get subgraph for node.
117    pub fn node_subgraph(&self, node_id: GraphNodeId) -> Option<GraphSubgraphId> {
118        self.node_subgraph.get(node_id).copied()
119    }
120
121    /// Degree into a node, i.e. the number of predecessors.
122    pub fn node_degree_in(&self, node_id: GraphNodeId) -> usize {
123        self.graph.degree_in(node_id)
124    }
125
126    /// Degree out of a node, i.e. the number of successors.
127    pub fn node_degree_out(&self, node_id: GraphNodeId) -> usize {
128        self.graph.degree_out(node_id)
129    }
130
131    /// Successors, iterator of `(GraphEdgeId, GraphNodeId)` of outgoing edges.
132    pub fn node_successors(
133        &self,
134        src: GraphNodeId,
135    ) -> impl '_
136    + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
137    + ExactSizeIterator
138    + FusedIterator
139    + Clone
140    + Debug {
141        self.graph.successors(src)
142    }
143
144    /// Predecessors, iterator of `(GraphEdgeId, GraphNodeId)` of incoming edges.
145    pub fn node_predecessors(
146        &self,
147        dst: GraphNodeId,
148    ) -> impl '_
149    + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
150    + ExactSizeIterator
151    + FusedIterator
152    + Clone
153    + Debug {
154        self.graph.predecessors(dst)
155    }
156
157    /// Successor edges, iterator of `GraphEdgeId` of outgoing edges.
158    pub fn node_successor_edges(
159        &self,
160        src: GraphNodeId,
161    ) -> impl '_
162    + DoubleEndedIterator<Item = GraphEdgeId>
163    + ExactSizeIterator
164    + FusedIterator
165    + Clone
166    + Debug {
167        self.graph.successor_edges(src)
168    }
169
170    /// Predecessor edges, iterator of `GraphEdgeId` of incoming edges.
171    pub fn node_predecessor_edges(
172        &self,
173        dst: GraphNodeId,
174    ) -> impl '_
175    + DoubleEndedIterator<Item = GraphEdgeId>
176    + ExactSizeIterator
177    + FusedIterator
178    + Clone
179    + Debug {
180        self.graph.predecessor_edges(dst)
181    }
182
183    /// Successor nodes, iterator of `GraphNodeId`.
184    pub fn node_successor_nodes(
185        &self,
186        src: GraphNodeId,
187    ) -> impl '_
188    + DoubleEndedIterator<Item = GraphNodeId>
189    + ExactSizeIterator
190    + FusedIterator
191    + Clone
192    + Debug {
193        self.graph.successor_vertices(src)
194    }
195
196    /// Predecessor nodes, iterator of `GraphNodeId`.
197    pub fn node_predecessor_nodes(
198        &self,
199        dst: GraphNodeId,
200    ) -> impl '_
201    + DoubleEndedIterator<Item = GraphNodeId>
202    + ExactSizeIterator
203    + FusedIterator
204    + Clone
205    + Debug {
206        self.graph.predecessor_vertices(dst)
207    }
208
209    /// Iterator of node IDs `GraphNodeId`.
210    pub fn node_ids(&self) -> slotmap::basic::Keys<'_, GraphNodeId, GraphNode> {
211        self.nodes.keys()
212    }
213
214    /// Iterator over `(GraphNodeId, &Node)` pairs.
215    pub fn nodes(&self) -> slotmap::basic::Iter<'_, GraphNodeId, GraphNode> {
216        self.nodes.iter()
217    }
218
219    /// Insert a node, assigning the given varname.
220    pub fn insert_node(
221        &mut self,
222        node: GraphNode,
223        varname_opt: Option<Ident>,
224        loop_opt: Option<GraphLoopId>,
225    ) -> GraphNodeId {
226        let node_id = self.nodes.insert(node);
227        if let Some(varname) = varname_opt {
228            self.node_varnames.insert(node_id, Varname(varname));
229        }
230        if let Some(loop_id) = loop_opt {
231            self.node_loops.insert(node_id, loop_id);
232            self.loop_nodes[loop_id].push(node_id);
233        }
234        node_id
235    }
236
237    /// Insert an operator instance for the given node. Panics if already set.
238    pub fn insert_node_op_inst(&mut self, node_id: GraphNodeId, op_inst: OperatorInstance) {
239        assert!(matches!(
240            self.nodes.get(node_id),
241            Some(GraphNode::Operator(_))
242        ));
243        let old_inst = self.operator_instances.insert(node_id, op_inst);
244        assert!(old_inst.is_none());
245    }
246
247    /// Assign all operator instances if not set. Write diagnostic messages/errors into `diagnostics`.
248    pub fn insert_node_op_insts_all(&mut self, diagnostics: &mut Diagnostics) {
249        let mut op_insts = Vec::new();
250        for (node_id, node) in self.nodes() {
251            let GraphNode::Operator(operator) = node else {
252                continue;
253            };
254            if self.node_op_inst(node_id).is_some() {
255                continue;
256            };
257
258            // Op constraints.
259            let Some(op_constraints) = find_op_op_constraints(operator) else {
260                diagnostics.push(Diagnostic::spanned(
261                    operator.path.span(),
262                    Level::Error,
263                    format!("Unknown operator `{}`", operator.name_string()),
264                ));
265                continue;
266            };
267
268            // Input and output ports.
269            let (input_ports, output_ports) = {
270                let mut input_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
271                    .node_predecessors(node_id)
272                    .map(|(edge_id, pred_id)| (self.edge_ports(edge_id).1, pred_id))
273                    .collect();
274                // Ensure sorted by port index.
275                input_edges.sort();
276                let input_ports: Vec<PortIndexValue> = input_edges
277                    .into_iter()
278                    .map(|(port, _pred)| port)
279                    .cloned()
280                    .collect();
281
282                // Collect output arguments (successors).
283                let mut output_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
284                    .node_successors(node_id)
285                    .map(|(edge_id, succ)| (self.edge_ports(edge_id).0, succ))
286                    .collect();
287                // Ensure sorted by port index.
288                output_edges.sort();
289                let output_ports: Vec<PortIndexValue> = output_edges
290                    .into_iter()
291                    .map(|(port, _succ)| port)
292                    .cloned()
293                    .collect();
294
295                (input_ports, output_ports)
296            };
297
298            // Generic arguments.
299            let generics = get_operator_generics(diagnostics, operator);
300            // Generic argument errors.
301            {
302                // Span of `generic_args` (if it exists), otherwise span of the operator name.
303                let generics_span = generics
304                    .generic_args
305                    .as_ref()
306                    .map(Spanned::span)
307                    .unwrap_or_else(|| operator.path.span());
308
309                if !op_constraints
310                    .persistence_args
311                    .contains(&generics.persistence_args.len())
312                {
313                    diagnostics.push(Diagnostic::spanned(
314                        generics.persistence_args_span().unwrap_or(generics_span),
315                        Level::Error,
316                        format!(
317                            "`{}` should have {} persistence lifetime arguments, actually has {}.",
318                            op_constraints.name,
319                            op_constraints.persistence_args.human_string(),
320                            generics.persistence_args.len()
321                        ),
322                    ));
323                }
324                if !op_constraints.type_args.contains(&generics.type_args.len()) {
325                    diagnostics.push(Diagnostic::spanned(
326                        generics.type_args_span().unwrap_or(generics_span),
327                        Level::Error,
328                        format!(
329                            "`{}` should have {} generic type arguments, actually has {}.",
330                            op_constraints.name,
331                            op_constraints.type_args.human_string(),
332                            generics.type_args.len()
333                        ),
334                    ));
335                }
336            }
337
338            op_insts.push((
339                node_id,
340                OperatorInstance {
341                    op_constraints,
342                    input_ports,
343                    output_ports,
344                    singletons_referenced: operator.singletons_referenced.clone(),
345                    generics,
346                    arguments_pre: operator.args.clone(),
347                    arguments_raw: operator.args_raw.clone(),
348                },
349            ));
350        }
351
352        for (node_id, op_inst) in op_insts {
353            self.insert_node_op_inst(node_id, op_inst);
354        }
355    }
356
357    /// Inserts a node between two existing nodes connected by the given `edge_id`.
358    ///
359    /// `edge`: (src, dst, dst_idx)
360    ///
361    /// Before: A (src) ------------> B (dst)
362    /// After:  A (src) -> X (new) -> B (dst)
363    ///
364    /// Returns the ID of X & ID of edge OUT of X.
365    ///
366    /// Note that both the edges will be new and `edge_id` will be removed. Both new edges will
367    /// get the edge type of the original edge.
368    pub fn insert_intermediate_node(
369        &mut self,
370        edge_id: GraphEdgeId,
371        new_node: GraphNode,
372    ) -> (GraphNodeId, GraphEdgeId) {
373        let span = Some(new_node.span());
374
375        // Make corresponding operator instance (if `node` is an operator).
376        let op_inst_opt = 'oc: {
377            let GraphNode::Operator(operator) = &new_node else {
378                break 'oc None;
379            };
380            let Some(op_constraints) = find_op_op_constraints(operator) else {
381                break 'oc None;
382            };
383            let (input_port, output_port) = self.ports.get(edge_id).cloned().unwrap();
384
385            let mut dummy_diagnostics = Diagnostics::new();
386            let generics = get_operator_generics(&mut dummy_diagnostics, operator);
387            assert!(dummy_diagnostics.is_empty());
388
389            Some(OperatorInstance {
390                op_constraints,
391                input_ports: vec![input_port],
392                output_ports: vec![output_port],
393                singletons_referenced: operator.singletons_referenced.clone(),
394                generics,
395                arguments_pre: operator.args.clone(),
396                arguments_raw: operator.args_raw.clone(),
397            })
398        };
399
400        // Insert new `node`.
401        let node_id = self.nodes.insert(new_node);
402        // Insert corresponding `OperatorInstance` if applicable.
403        if let Some(op_inst) = op_inst_opt {
404            self.operator_instances.insert(node_id, op_inst);
405        }
406        // Update edges to insert node within `edge_id`.
407        let (e0, e1) = self
408            .graph
409            .insert_intermediate_vertex(node_id, edge_id)
410            .unwrap();
411
412        // Update corresponding ports.
413        let (src_idx, dst_idx) = self.ports.remove(edge_id).unwrap();
414        self.ports
415            .insert(e0, (src_idx, PortIndexValue::Elided(span)));
416        self.ports
417            .insert(e1, (PortIndexValue::Elided(span), dst_idx));
418
419        (node_id, e1)
420    }
421
422    /// Remove the node `node_id` but preserves and connects the single predecessor and single successor.
423    /// Panics if the node does not have exactly one predecessor and one successor, or is not in the graph.
424    pub fn remove_intermediate_node(&mut self, node_id: GraphNodeId) {
425        assert_eq!(
426            1,
427            self.node_degree_in(node_id),
428            "Removed intermediate node must have one predecessor"
429        );
430        assert_eq!(
431            1,
432            self.node_degree_out(node_id),
433            "Removed intermediate node must have one successor"
434        );
435        assert!(
436            self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
437            "Should not remove intermediate node after subgraph partitioning"
438        );
439
440        assert!(self.nodes.remove(node_id).is_some());
441        let (new_edge_id, (pred_edge_id, succ_edge_id)) =
442            self.graph.remove_intermediate_vertex(node_id).unwrap();
443        self.operator_instances.remove(node_id);
444        self.node_varnames.remove(node_id);
445
446        let (src_port, _) = self.ports.remove(pred_edge_id).unwrap();
447        let (_, dst_port) = self.ports.remove(succ_edge_id).unwrap();
448        self.ports.insert(new_edge_id, (src_port, dst_port));
449    }
450
451    /// Helper method: determine the "color" (pull vs push) of a node based on its in and out degree,
452    /// excluding reference edges. If linear (1 in, 1 out), color is `None`, indicating it can be
453    /// either push or pull.
454    ///
455    /// Note that this does NOT consider `DelayType` barriers (which generally implies `Pull`).
456    pub(crate) fn node_color(&self, node_id: GraphNodeId) -> Option<Color> {
457        if matches!(self.node(node_id), GraphNode::Handoff { .. }) {
458            return Some(Color::Hoff);
459        }
460
461        // TODO(shadaj): this is a horrible hack
462        if let GraphNode::Operator(op) = self.node(node_id)
463            && (op.name_string() == "resolve_futures_blocking"
464                || op.name_string() == "resolve_futures_blocking_ordered")
465        {
466            return Some(Color::Push);
467        }
468
469        // In-degree, excluding ref-edges.
470        let inn_degree = self.node_predecessor_nodes(node_id).len();
471        // Out-degree excluding ref-edges.
472        let out_degree = self.node_successor_nodes(node_id).len();
473
474        match (inn_degree, out_degree) {
475            (0, 0) => None, // Generally should not happen, "Degenerate subgraph detected".
476            (0, 1) => Some(Color::Pull),
477            (1, 0) => Some(Color::Push),
478            (1, 1) => None, // Linear, can be either push or pull.
479            (_many, 0 | 1) => Some(Color::Pull),
480            (0 | 1, _many) => Some(Color::Push),
481            (_many, _to_many) => Some(Color::Comp),
482        }
483    }
484
485    /// Set the operator tag (for debugging/tracing).
486    pub fn set_operator_tag(&mut self, node_id: GraphNodeId, tag: String) {
487        self.operator_tag.insert(node_id, tag);
488    }
489}
490
491/// Singleton references.
492impl DfirGraph {
493    /// Set the singletons referenced for the `node_id` operator. Each reference corresponds to the
494    /// same index in the [`crate::parse::Operator::singletons_referenced`] vec.
495    pub fn set_node_singleton_references(
496        &mut self,
497        node_id: GraphNodeId,
498        singletons_referenced: Vec<Option<GraphNodeId>>,
499    ) -> Option<Vec<Option<GraphNodeId>>> {
500        self.node_singleton_references
501            .insert(node_id, singletons_referenced)
502    }
503
504    /// Gets the singletons referenced by a node. Returns an empty iterator for non-operators and
505    /// operators that do not reference singletons.
506    pub fn node_singleton_references(&self, node_id: GraphNodeId) -> &[Option<GraphNodeId>] {
507        self.node_singleton_references
508            .get(node_id)
509            .map(std::ops::Deref::deref)
510            .unwrap_or_default()
511    }
512
513    /// Returns whether the given node produces exactly one item (is a singleton).
514    pub fn node_is_singleton(&self, node_id: GraphNodeId) -> bool {
515        self.node_is_singleton.contains_key(node_id)
516    }
517
518    /// Computes the `node_is_singleton` field by propagation.
519    /// A node is a singleton if:
520    /// - Its operator has `has_singleton_output: true`, OR
521    /// - All its predecessor nodes are singletons AND its operator has `preserves_singleton: true`.
522    pub fn compute_node_singletons(&mut self) {
523        // Iterate in topological order (node_ids are already topo-sorted after partitioning).
524        let node_ids: Vec<_> = self.node_ids().collect();
525        for node_id in node_ids {
526            let Some(op_inst) = self.operator_instances.get(node_id) else {
527                continue;
528            };
529            let op_constraints = op_inst.op_constraints;
530
531            if op_constraints.has_singleton_output {
532                self.node_is_singleton.insert(node_id, ());
533            } else if op_constraints.preserves_singleton {
534                // Check if all predecessors are singletons.
535                let all_preds_singleton = self
536                    .node_predecessor_nodes(node_id)
537                    .all(|pred_id| self.node_is_singleton.contains_key(pred_id));
538                // Must have at least one predecessor to inherit singleton status.
539                let has_preds = self.node_predecessor_nodes(node_id).next().is_some();
540                if has_preds && all_preds_singleton {
541                    self.node_is_singleton.insert(node_id, ());
542                }
543            }
544        }
545    }
546}
547
548/// Module methods.
549impl DfirGraph {
550    /// When modules are imported into a flat graph, they come with an input and output ModuleBoundary node.
551    /// The partitioner doesn't understand these nodes and will panic if it encounters them.
552    /// merge_modules removes them from the graph, stitching the input and ouput sides of the ModuleBondaries based on their ports
553    /// For example:
554    ///     source_iter([]) -> \[myport\]ModuleBoundary(input)\[my_port\] -> map(|x| x) -> ModuleBoundary(output) -> null();
555    /// in the above eaxmple, the \[myport\] port will be used to connect the source_iter with the map that is inside of the module.
556    /// The output module boundary has elided ports, this is also used to match up the input/output across the module boundary.
557    pub fn merge_modules(&mut self) -> Result<(), Diagnostic> {
558        let mod_bound_nodes = self
559            .nodes()
560            .filter(|(_nid, node)| matches!(node, GraphNode::ModuleBoundary { .. }))
561            .map(|(nid, _node)| nid)
562            .collect::<Vec<_>>();
563
564        for mod_bound_node in mod_bound_nodes {
565            self.remove_module_boundary(mod_bound_node)?;
566        }
567
568        Ok(())
569    }
570
571    /// see `merge_modules`
572    /// This function removes a singular module boundary from the graph and performs the necessary stitching to fix the graph afterward.
573    /// `merge_modules` calls this function for each module boundary in the graph.
574    fn remove_module_boundary(&mut self, mod_bound_node: GraphNodeId) -> Result<(), Diagnostic> {
575        assert!(
576            self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
577            "Should not remove intermediate node after subgraph partitioning"
578        );
579
580        let mut mod_pred_ports = BTreeMap::new();
581        let mut mod_succ_ports = BTreeMap::new();
582
583        for mod_out_edge in self.node_predecessor_edges(mod_bound_node) {
584            let (pred_port, succ_port) = self.edge_ports(mod_out_edge);
585            mod_pred_ports.insert(succ_port.clone(), (mod_out_edge, pred_port.clone()));
586        }
587
588        for mod_inn_edge in self.node_successor_edges(mod_bound_node) {
589            let (pred_port, succ_port) = self.edge_ports(mod_inn_edge);
590            mod_succ_ports.insert(pred_port.clone(), (mod_inn_edge, succ_port.clone()));
591        }
592
593        if mod_pred_ports.keys().collect::<BTreeSet<_>>()
594            != mod_succ_ports.keys().collect::<BTreeSet<_>>()
595        {
596            // get module boundary node
597            let GraphNode::ModuleBoundary { input, import_expr } = self.node(mod_bound_node) else {
598                panic!();
599            };
600
601            if *input {
602                return Err(Diagnostic {
603                    span: *import_expr,
604                    level: Level::Error,
605                    message: format!(
606                        "The ports into the module did not match. input: {:?}, expected: {:?}",
607                        mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
608                        mod_succ_ports.keys().map(|x| x.to_string()).join(", ")
609                    ),
610                });
611            } else {
612                return Err(Diagnostic {
613                    span: *import_expr,
614                    level: Level::Error,
615                    message: format!(
616                        "The ports out of the module did not match. output: {:?}, expected: {:?}",
617                        mod_succ_ports.keys().map(|x| x.to_string()).join(", "),
618                        mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
619                    ),
620                });
621            }
622        }
623
624        for (port, (pred_edge, pred_port)) in mod_pred_ports {
625            let (succ_edge, succ_port) = mod_succ_ports.remove(&port).unwrap();
626
627            let (src, _) = self.edge(pred_edge);
628            let (_, dst) = self.edge(succ_edge);
629            self.remove_edge(pred_edge);
630            self.remove_edge(succ_edge);
631
632            let new_edge_id = self.graph.insert_edge(src, dst);
633            self.ports.insert(new_edge_id, (pred_port, succ_port));
634        }
635
636        self.graph.remove_vertex(mod_bound_node);
637        self.nodes.remove(mod_bound_node);
638
639        Ok(())
640    }
641}
642
643/// Edge methods.
644impl DfirGraph {
645    /// Get the `src` and `dst` for an edge: `(src GraphNodeId, dst GraphNodeId)`.
646    pub fn edge(&self, edge_id: GraphEdgeId) -> (GraphNodeId, GraphNodeId) {
647        let (src, dst) = self.graph.edge(edge_id).expect("Edge not found.");
648        (src, dst)
649    }
650
651    /// Get the source and destination ports for an edge: `(src &PortIndexValue, dst &PortIndexValue)`.
652    pub fn edge_ports(&self, edge_id: GraphEdgeId) -> (&PortIndexValue, &PortIndexValue) {
653        let (src_port, dst_port) = self.ports.get(edge_id).expect("Edge not found.");
654        (src_port, dst_port)
655    }
656
657    /// Iterator of all edge IDs `GraphEdgeId`.
658    pub fn edge_ids(&self) -> slotmap::basic::Keys<'_, GraphEdgeId, (GraphNodeId, GraphNodeId)> {
659        self.graph.edge_ids()
660    }
661
662    /// Iterator over all edges: `(GraphEdgeId, (src GraphNodeId, dst GraphNodeId))`.
663    pub fn edges(
664        &self,
665    ) -> impl '_
666    + ExactSizeIterator<Item = (GraphEdgeId, (GraphNodeId, GraphNodeId))>
667    + FusedIterator
668    + Clone
669    + Debug {
670        self.graph.edges()
671    }
672
673    /// Insert an edge between nodes thru the given ports.
674    pub fn insert_edge(
675        &mut self,
676        src: GraphNodeId,
677        src_port: PortIndexValue,
678        dst: GraphNodeId,
679        dst_port: PortIndexValue,
680    ) -> GraphEdgeId {
681        let edge_id = self.graph.insert_edge(src, dst);
682        self.ports.insert(edge_id, (src_port, dst_port));
683        edge_id
684    }
685
686    /// Removes an edge and its corresponding ports and edge type info.
687    pub fn remove_edge(&mut self, edge: GraphEdgeId) {
688        let (_src, _dst) = self.graph.remove_edge(edge).unwrap();
689        let (_src_port, _dst_port) = self.ports.remove(edge).unwrap();
690    }
691}
692
693/// Subgraph methods.
694impl DfirGraph {
695    /// Nodes belonging to the given subgraph.
696    pub fn subgraph(&self, subgraph_id: GraphSubgraphId) -> &Vec<GraphNodeId> {
697        self.subgraph_nodes
698            .get(subgraph_id)
699            .expect("Subgraph not found.")
700    }
701
702    /// Iterator over all subgraph IDs.
703    pub fn subgraph_ids(&self) -> slotmap::basic::Keys<'_, GraphSubgraphId, Vec<GraphNodeId>> {
704        self.subgraph_nodes.keys()
705    }
706
707    /// Iterator over all subgraphs, ID and members: `(GraphSubgraphId, Vec<GraphNodeId>)`.
708    pub fn subgraphs(&self) -> slotmap::basic::Iter<'_, GraphSubgraphId, Vec<GraphNodeId>> {
709        self.subgraph_nodes.iter()
710    }
711
712    /// Create a subgraph consisting of `node_ids`. Returns an error if any of the nodes are already in a subgraph.
713    pub fn insert_subgraph(
714        &mut self,
715        node_ids: Vec<GraphNodeId>,
716    ) -> Result<GraphSubgraphId, (GraphNodeId, GraphSubgraphId)> {
717        // Check none are already in subgraphs
718        for &node_id in node_ids.iter() {
719            if let Some(&old_sg_id) = self.node_subgraph.get(node_id) {
720                return Err((node_id, old_sg_id));
721            }
722        }
723        let subgraph_id = self.subgraph_nodes.insert_with_key(|sg_id| {
724            for &node_id in node_ids.iter() {
725                self.node_subgraph.insert(node_id, sg_id);
726            }
727            node_ids
728        });
729
730        Ok(subgraph_id)
731    }
732
733    /// Removes a node from its subgraph. Returns true if the node was in a subgraph.
734    pub fn remove_from_subgraph(&mut self, node_id: GraphNodeId) -> bool {
735        if let Some(old_sg_id) = self.node_subgraph.remove(node_id) {
736            self.subgraph_nodes[old_sg_id].retain(|&other_node_id| other_node_id != node_id);
737            true
738        } else {
739            false
740        }
741    }
742
743    /// Gets the delay type for a handoff node, if set.
744    pub fn handoff_delay_type(&self, node_id: GraphNodeId) -> Option<DelayType> {
745        self.handoff_delay_type.get(node_id).copied()
746    }
747
748    /// Sets the delay type for a handoff node.
749    pub fn set_handoff_delay_type(&mut self, node_id: GraphNodeId, delay_type: DelayType) {
750        self.handoff_delay_type.insert(node_id, delay_type);
751    }
752
753    /// Helper: finds the first index in `subgraph_nodes` where it transitions from pull to push.
754    fn find_pull_to_push_idx(&self, subgraph_nodes: &[GraphNodeId]) -> usize {
755        subgraph_nodes
756            .iter()
757            .position(|&node_id| {
758                self.node_color(node_id)
759                    .is_some_and(|color| Color::Pull != color)
760            })
761            .unwrap_or(subgraph_nodes.len())
762    }
763}
764
765/// Display/output methods.
766impl DfirGraph {
767    /// Helper to generate a deterministic `Ident` for the given node.
768    fn node_as_ident(&self, node_id: GraphNodeId, is_pred: bool) -> Ident {
769        let name = match &self.nodes[node_id] {
770            GraphNode::Operator(_) => format!("op_{:?}", node_id.data()),
771            GraphNode::Handoff { .. } => format!(
772                "hoff_{:?}_{}",
773                node_id.data(),
774                if is_pred { "recv" } else { "send" }
775            ),
776            GraphNode::ModuleBoundary { .. } => panic!(),
777        };
778        let span = match (is_pred, &self.nodes[node_id]) {
779            (_, GraphNode::Operator(operator)) => operator.span(),
780            (true, &GraphNode::Handoff { src_span, .. }) => src_span,
781            (false, &GraphNode::Handoff { dst_span, .. }) => dst_span,
782            (_, GraphNode::ModuleBoundary { .. }) => panic!(),
783        };
784        Ident::new(&name, span)
785    }
786
787    /// Helper to generate the main buffer `Ident` for a handoff node.
788    fn hoff_buf_ident(&self, hoff_id: GraphNodeId, span: Span) -> Ident {
789        Ident::new(&format!("hoff_{:?}_buf", hoff_id.data()), span)
790    }
791
792    /// Helper to generate the back (double-buffer) `Ident` for a handoff node.
793    fn hoff_back_ident(&self, hoff_id: GraphNodeId, span: Span) -> Ident {
794        Ident::new(&format!("hoff_{:?}_back", hoff_id.data()), span)
795    }
796
797    /// For per-node singleton references. Helper to generate a deterministic `Ident` for the given node.
798    fn node_as_singleton_ident(&self, node_id: GraphNodeId, span: Span) -> Ident {
799        Ident::new(&format!("singleton_op_{:?}", node_id.data()), span)
800    }
801
802    /// Resolve the singletons via [`Self::node_singleton_references`] for the given `node_id`.
803    fn helper_resolve_singletons(&self, node_id: GraphNodeId, span: Span) -> Vec<Ident> {
804        self.node_singleton_references(node_id)
805            .iter()
806            .map(|singleton_node_id| {
807                // TODO(mingwei): this `expect` should be caught in error checking
808                self.node_as_singleton_ident(
809                    singleton_node_id
810                        .expect("Expected singleton to be resolved but was not, this is a bug."),
811                    span,
812                )
813            })
814            .collect::<Vec<_>>()
815    }
816
817    /// Returns each subgraph's receive and send handoffs.
818    /// `Map<GraphSubgraphId, (recv handoffs, send handoffs)>`
819    fn helper_collect_subgraph_handoffs(
820        &self,
821    ) -> SecondaryMap<GraphSubgraphId, (Vec<GraphNodeId>, Vec<GraphNodeId>)> {
822        // Get data on handoff src and dst subgraphs.
823        let mut subgraph_handoffs: SecondaryMap<
824            GraphSubgraphId,
825            (Vec<GraphNodeId>, Vec<GraphNodeId>),
826        > = self
827            .subgraph_nodes
828            .keys()
829            .map(|k| (k, Default::default()))
830            .collect();
831
832        // For each handoff node, add it to the `send`/`recv` lists for the corresponding subgraphs.
833        for (hoff_id, node) in self.nodes() {
834            if !matches!(node, GraphNode::Handoff { .. }) {
835                continue;
836            }
837            // Receivers from the handoff. (Should really only be one).
838            for (_edge, succ_id) in self.node_successors(hoff_id) {
839                let succ_sg = self.node_subgraph(succ_id).unwrap();
840                subgraph_handoffs[succ_sg].0.push(hoff_id);
841            }
842            // Senders into the handoff. (Should really only be one).
843            for (_edge, pred_id) in self.node_predecessors(hoff_id) {
844                let pred_sg = self.node_subgraph(pred_id).unwrap();
845                subgraph_handoffs[pred_sg].1.push(hoff_id);
846            }
847        }
848
849        subgraph_handoffs
850    }
851
852    /// Emit this graph as runnable Rust source code tokens that execute inline.
853    /// Generates a flat `async move |df: &mut Context|` closure where subgraph
854    /// blocks are inlined in topological order, using local `Vec<T>` buffers
855    /// instead of runtime handoffs. Each call to the closure runs one tick.
856    ///
857    /// The generated code block evaluates to a `Dfir` instance wrapping the
858    /// closure. Operator prologues run at construction time on the `Context`
859    /// before it is moved into `Dfir::new`. `Dfir` provides the `Context`
860    /// to the closure on each tick run.
861    ///
862    /// # Errors
863    ///
864    /// Returns all diagnostics as `Err(diagnostics)` if any are errors
865    /// (leaving `&mut diagnostics` empty).
866    pub fn as_code(
867        &self,
868        root: &TokenStream,
869        include_type_guards: bool,
870        prefix: TokenStream,
871        diagnostics: &mut Diagnostics,
872    ) -> Result<TokenStream, Diagnostics> {
873        self.as_code_with_options(root, include_type_guards, true, prefix, diagnostics)
874    }
875
876    /// Like [`Self::as_code`], but with `include_meta` controlling whether
877    /// the runtime meta graph + diagnostics JSON blobs are baked into the
878    /// generated `Dfir::new(...)` call.
879    ///
    /// The simulator calls `Dfir::new()` on each iteration, and as part of that it parses the
    /// meta graph and diagnostics blobs. One of these causes spans to be allocated. Each span
    /// allocation increments a thread-local `u32` counter, so on a long simulator run that
    /// counter overflows and panics.
884    pub fn as_code_with_options(
885        &self,
886        root: &TokenStream,
887        include_type_guards: bool,
888        include_meta: bool,
889        prefix: TokenStream,
890        diagnostics: &mut Diagnostics,
891    ) -> Result<TokenStream, Diagnostics> {
892        let df = Ident::new(GRAPH, Span::call_site());
893        let context = Ident::new(CONTEXT, Span::call_site());
894
895        // 1. Generate local Vec buffers for each handoff node.
896        let handoff_nodes: Vec<_> = self
897            .nodes
898            .iter()
899            .filter_map(|(node_id, node)| match node {
900                GraphNode::Operator(_) => None,
901                &GraphNode::Handoff { src_span, dst_span } => Some((node_id, (src_span, dst_span))),
902                GraphNode::ModuleBoundary { .. } => panic!(),
903            })
904            .collect();
905
906        let buffer_code: Vec<TokenStream> = handoff_nodes
907            .iter()
908            .map(|&(node_id, (src_span, dst_span))| {
909                let span = src_span.join(dst_span).unwrap_or(src_span);
910                let buf_ident = self.hoff_buf_ident(node_id, span);
911                quote_spanned! {span=>
912                    let mut #buf_ident: Vec<_> = Vec::new();
913                }
914            })
915            .collect();
916
917        // For tick-boundary handoffs (`defer_tick` / `defer_tick_lazy`), declare a
918        // second "back" buffer for double-buffering. At the start of each tick, the
919        // main buffer and back buffer are swapped so the consumer reads last tick's
920        // data while the producer writes to a fresh buffer.
921        let back_buffer_code: Vec<TokenStream> = handoff_nodes
922            .iter()
923            .filter(|(node_id, _)| self.handoff_delay_type(*node_id).is_some())
924            .map(|&(node_id, (src_span, dst_span))| {
925                let span = src_span.join(dst_span).unwrap_or(src_span);
926                let back_ident = self.hoff_back_ident(node_id, span);
927                quote_spanned! {span=>
928                    let mut #back_ident: Vec<_> = Vec::new();
929                }
930            })
931            .collect();
932
933        // 2. Collect subgraph handoffs (same as as_code).
934        let subgraph_handoffs = self.helper_collect_subgraph_handoffs();
935
936        // 3. Sort subgraphs topologically and collect non-lazy defer_tick buffer idents.
937        //
938        // Handoffs marked with a `DelayType` (Tick/TickLazy) are tick-boundary back-edges.
939        // These are excluded from the topo sort (no ordering constraint). Double-buffering
940        // ensures data written by the producer in tick N is only visible to the consumer
941        // in tick N+1, regardless of execution order.
942        //
943        // While iterating handoffs, we also collect buffer idents for non-lazy tick-boundary
944        // edges (defer_tick). When these buffers are non-empty at end of tick, we set
945        // can_start_tick so that run_available continues ticking.
946        let mut defer_tick_buf_idents: Vec<Ident> = Vec::new();
947        let mut back_edge_hoff_ids: BTreeSet<GraphNodeId> = BTreeSet::new();
948        let all_subgraphs = {
949            // Build predecessor map for subgraphs.
950            let mut sg_preds = SecondaryMap::<_, Vec<_>>::with_capacity(self.subgraph_nodes.len());
951            for (hoff_id, node) in self.nodes() {
952                if !matches!(node, GraphNode::Handoff { .. }) {
953                    // Not a handoff; skip.
954                    continue;
955                }
956                assert_eq!(1, self.node_successors(hoff_id).len());
957                assert_eq!(1, self.node_predecessors(hoff_id).len());
958                let (_edge_id, pred) = self.node_predecessors(hoff_id).next().unwrap();
959                let (_edge_id, succ) = self.node_successors(hoff_id).next().unwrap();
960                let pred_sg = self.node_subgraph(pred).unwrap();
961                let succ_sg = self.node_subgraph(succ).unwrap();
962                if pred_sg == succ_sg {
963                    panic!("bug: unexpected subgraph self-handoff cycle");
964                }
965                if let Some(delay_type) = self.handoff_delay_type(hoff_id) {
966                    debug_assert!(matches!(delay_type, DelayType::Tick | DelayType::TickLazy));
967                    // Tick/back-edge handoff: no ordering constraint. Double-buffering
968                    // handles the tick deferral regardless of execution order.
969                    back_edge_hoff_ids.insert(hoff_id);
970
971                    // Non-lazy tick-boundary: defer_tick (not defer_tick_lazy).
972                    if !matches!(delay_type, DelayType::TickLazy) {
973                        defer_tick_buf_idents.push(self.hoff_buf_ident(hoff_id, node.span()));
974                    }
975                } else {
976                    sg_preds.entry(succ_sg).unwrap().or_default().push(pred_sg);
977                }
978            }
979
980            // Include singleton reference edges: if node A references the
981            // singleton output of node B, then A's subgraph must run after B's.
982            for dst_id in self.node_ids() {
983                for src_ref_id in self
984                    .node_singleton_references(dst_id)
985                    .iter()
986                    .copied()
987                    .flatten()
988                {
989                    let src_sg = self
990                        .node_subgraph(src_ref_id)
991                        .expect("bug: singleton ref node must belong to a subgraph");
992                    let dst_sg = self
993                        .node_subgraph(dst_id)
994                        .expect("bug: singleton ref consumer must belong to a subgraph");
995                    if src_sg != dst_sg {
996                        sg_preds.entry(dst_sg).unwrap().or_default().push(src_sg);
997                    }
998                }
999            }
1000
1001            let topo_sort = super::graph_algorithms::topo_sort(self.subgraph_ids(), |sg_id| {
1002                sg_preds.get(sg_id).into_iter().flatten().copied()
1003            })
1004            .expect("bug: unexpected cycle between subgraphs within the tick");
1005
1006            topo_sort
1007                .into_iter()
1008                .map(|sg_id| (sg_id, self.subgraph(sg_id)))
1009                .collect::<Vec<_>>()
1010        };
1011
1012        // Generate swap code for tick-boundary (defer_tick / defer_tick_lazy) handoffs.
1013        // At the start of each tick, swap the main buffer and back buffer so the
1014        // consumer reads last tick's data from the back buffer.
1015        let back_edge_swap_code: Vec<TokenStream> = back_edge_hoff_ids
1016            .iter()
1017            .map(|&hoff_id| {
1018                let span = self.nodes[hoff_id].span();
1019                let buf_ident = self.hoff_buf_ident(hoff_id, span);
1020                let back_ident = self.hoff_back_ident(hoff_id, span);
1021                quote_spanned! {span=>
1022                    ::std::mem::swap(&mut #buf_ident, &mut #back_ident);
1023                }
1024            })
1025            .collect();
1026
1027        let mut op_prologue_code = Vec::new();
1028        let mut op_tick_end_code = Vec::new();
1029        let mut subgraph_blocks = Vec::new();
1030        {
1031            for &(subgraph_id, subgraph_nodes) in all_subgraphs.iter() {
1032                let sg_metrics_ffi = subgraph_id.data().as_ffi();
1033                let (recv_hoffs, send_hoffs) = &subgraph_handoffs[subgraph_id];
1034
1035                // Generate buffer ident helpers for this subgraph's handoffs.
1036                let recv_port_idents: Vec<Ident> = recv_hoffs
1037                    .iter()
1038                    .map(|&hoff_id| self.node_as_ident(hoff_id, true))
1039                    .collect();
1040                let send_port_idents: Vec<Ident> = send_hoffs
1041                    .iter()
1042                    .map(|&hoff_id| self.node_as_ident(hoff_id, false))
1043                    .collect();
1044
1045                // Map handoff node IDs to buffer idents.
1046                let recv_buf_idents: Vec<Ident> = recv_hoffs
1047                    .iter()
1048                    .map(|&hoff_id| self.hoff_buf_ident(hoff_id, self.nodes[hoff_id].span()))
1049                    .collect();
1050                let send_buf_idents: Vec<Ident> = send_hoffs
1051                    .iter()
1052                    .map(|&hoff_id| self.hoff_buf_ident(hoff_id, self.nodes[hoff_id].span()))
1053                    .collect();
1054
1055                // Recv port code: drain from buffer into iterator, tracking if non-empty.
1056                // For back-edge (defer_tick) handoffs, drain from the back buffer instead.
1057                // Also update handoff metrics (measured at recv, not send — see graph.rs).
1058                let recv_port_code: Vec<TokenStream> = recv_port_idents
1059                    .iter()
1060                    .zip(recv_buf_idents.iter())
1061                    .zip(recv_hoffs.iter())
1062                    .map(|((port_ident, buf_ident), &hoff_id)| {
1063                        let hoff_ffi = hoff_id.data().as_ffi();
1064                        // Use call_site span for internal identifiers to avoid
1065                        // hygiene issues when invoked through declarative macros
1066                        // (e.g. dfir_expect_warnings!). TODO(#2781): define these once.
1067                        let work_done = Ident::new("__dfir_work_done", Span::call_site());
1068                        let metrics = Ident::new("__dfir_metrics", Span::call_site());
1069                        // Tick-boundary handoffs drain from the back buffer (double-buffering).
1070                        // (Sending always writes to the regular buffer — no branch needed there.)
1071                        let drain_ident = if back_edge_hoff_ids.contains(&hoff_id) {
1072                            self.hoff_back_ident(hoff_id, buf_ident.span())
1073                        } else {
1074                            buf_ident.clone()
1075                        };
1076                        quote_spanned! {port_ident.span()=>
1077                            {
1078                                let hoff_len = #drain_ident.len();
1079                                if hoff_len > 0 {
1080                                    #work_done = true;
1081                                }
1082                                let hoff_metrics = &#metrics.handoffs[
1083                                    #root::slotmap::KeyData::from_ffi(#hoff_ffi).into()
1084                                ];
1085                                hoff_metrics.total_items_count.update(|x| x + hoff_len);
1086                                hoff_metrics.curr_items_count.set(hoff_len);
1087                            }
1088                            let #port_ident = #root::dfir_pipes::pull::iter(#drain_ident.drain(..));
1089                        }
1090                    })
1091                    .collect();
1092
1093                // Send port code: push into buffer.
1094                let send_port_code: Vec<TokenStream> = send_port_idents
1095                    .iter()
1096                    .zip(send_buf_idents.iter())
1097                    .map(|(port_ident, buf_ident)| {
1098                        quote_spanned! {port_ident.span()=>
1099                            let #port_ident = #root::dfir_pipes::push::vec_push(&mut #buf_ident);
1100                        }
1101                    })
1102                    .collect();
1103
1104                // All nodes in a subgraph should be in the same loop.
1105                let loop_id = self.node_loop(subgraph_nodes[0]);
1106
1107                let mut subgraph_op_iter_code = Vec::new();
1108                let mut subgraph_op_iter_after_code = Vec::new();
1109                {
1110                    let pull_to_push_idx = self.find_pull_to_push_idx(subgraph_nodes);
1111
1112                    let (pull_half, push_half) = subgraph_nodes.split_at(pull_to_push_idx);
1113                    let nodes_iter = pull_half.iter().chain(push_half.iter().rev());
1114
1115                    for (idx, &node_id) in nodes_iter.enumerate() {
1116                        let node = &self.nodes[node_id];
1117                        assert!(
1118                            matches!(node, GraphNode::Operator(_)),
1119                            "Handoffs are not part of subgraphs."
1120                        );
1121                        let op_inst = &self.operator_instances[node_id];
1122
1123                        let op_span = node.span();
1124                        let op_name = op_inst.op_constraints.name;
1125                        // Use op's span for root. #root is expected to be correct, any errors should span back to the op gen.
1126                        let root = change_spans(root.clone(), op_span);
1127                        let op_constraints = OPERATORS
1128                            .iter()
1129                            .find(|op| op_name == op.name)
1130                            .unwrap_or_else(|| panic!("Failed to find op: {}", op_name));
1131
1132                        let ident = self.node_as_ident(node_id, false);
1133
1134                        {
1135                            // TODO clean this up.
1136                            // Collect input arguments (predecessors).
1137                            let mut input_edges = self
1138                                .graph
1139                                .predecessor_edges(node_id)
1140                                .map(|edge_id| (self.edge_ports(edge_id).1, edge_id))
1141                                .collect::<Vec<_>>();
1142                            // Ensure sorted by port index.
1143                            input_edges.sort();
1144
1145                            let inputs = input_edges
1146                                .iter()
1147                                .map(|&(_port, edge_id)| {
1148                                    let (pred, _) = self.edge(edge_id);
1149                                    self.node_as_ident(pred, true)
1150                                })
1151                                .collect::<Vec<_>>();
1152
1153                            // Collect output arguments (successors).
1154                            let mut output_edges = self
1155                                .graph
1156                                .successor_edges(node_id)
1157                                .map(|edge_id| (&self.ports[edge_id].0, edge_id))
1158                                .collect::<Vec<_>>();
1159                            // Ensure sorted by port index.
1160                            output_edges.sort();
1161
1162                            let outputs = output_edges
1163                                .iter()
1164                                .map(|&(_port, edge_id)| {
1165                                    let (_, succ) = self.edge(edge_id);
1166                                    self.node_as_ident(succ, false)
1167                                })
1168                                .collect::<Vec<_>>();
1169
1170                            let is_pull = idx < pull_to_push_idx;
1171
1172                            let singleton_output_ident = &if op_constraints.has_singleton_output {
1173                                self.node_as_singleton_ident(node_id, op_span)
1174                            } else {
1175                                // This ident *should* go unused.
1176                                Ident::new(&format!("{}_has_no_singleton_output", op_name), op_span)
1177                            };
1178
1179                            // There's a bit of dark magic hidden in `Span`s... you'd think it's just a `file:line:column`,
1180                            // but it has one extra bit of info for _name resolution_, used for `Ident`s. `Span::call_site()`
1181                            // has the (unhygienic) resolution we want, an ident is just solely determined by its string name,
1182                            // which is what you'd expect out of unhygienic proc macros like this. Meanwhile, declarative macros
1183                            // use `Span::mixed_site()` which is weird and I don't understand it. It turns out that if you call
1184                            // the dfir syntax proc macro from _within_ a declarative macro then `op_span` will have the
1185                            // bad `Span::mixed_site()` name resolution and cause "Cannot find value `df/context`" errors. So
1186                            // we call `.resolved_at()` to fix resolution back to `Span::call_site()`. -Mingwei
1187                            let df_local = &Ident::new(GRAPH, op_span.resolved_at(df.span()));
1188                            let context = &Ident::new(CONTEXT, op_span.resolved_at(context.span()));
1189
1190                            let singletons_resolved =
1191                                self.helper_resolve_singletons(node_id, op_span);
1192                            let arguments = &process_singletons::postprocess_singletons(
1193                                op_inst.arguments_raw.clone(),
1194                                singletons_resolved.clone(),
1195                            );
1196                            let arguments_handles =
1197                                &process_singletons::postprocess_singletons_handles(
1198                                    op_inst.arguments_raw.clone(),
1199                                    singletons_resolved.clone(),
1200                                );
1201
1202                            let source_tag = 'a: {
1203                                if let Some(tag) = self.operator_tag.get(node_id).cloned() {
1204                                    break 'a tag;
1205                                }
1206
1207                                #[cfg(nightly)]
1208                                if proc_macro::is_available() {
1209                                    let op_span = op_span.unwrap();
1210                                    break 'a format!(
1211                                        "loc_{}_{}_{}_{}_{}",
1212                                        crate::pretty_span::make_source_path_relative(
1213                                            &op_span.file()
1214                                        )
1215                                        .display()
1216                                        .to_string()
1217                                        .replace(|x: char| !x.is_ascii_alphanumeric(), "_"),
1218                                        op_span.start().line(),
1219                                        op_span.start().column(),
1220                                        op_span.end().line(),
1221                                        op_span.end().column(),
1222                                    );
1223                                }
1224
1225                                format!(
1226                                    "loc_nopath_{}_{}_{}_{}",
1227                                    op_span.start().line,
1228                                    op_span.start().column,
1229                                    op_span.end().line,
1230                                    op_span.end().column
1231                                )
1232                            };
1233
1234                            let work_fn = format_ident!(
1235                                "{}__{}__{}",
1236                                ident,
1237                                op_name,
1238                                source_tag,
1239                                span = op_span
1240                            );
1241                            let work_fn_async = format_ident!("{}__async", work_fn, span = op_span);
1242
1243                            let context_args = WriteContextArgs {
1244                                root: &root,
1245                                df_ident: df_local,
1246                                context,
1247                                subgraph_id,
1248                                node_id,
1249                                loop_id,
1250                                op_span,
1251                                op_tag: self.operator_tag.get(node_id).cloned(),
1252                                work_fn: &work_fn,
1253                                work_fn_async: &work_fn_async,
1254                                ident: &ident,
1255                                is_pull,
1256                                inputs: &inputs,
1257                                outputs: &outputs,
1258                                singleton_output_ident,
1259                                op_name,
1260                                op_inst,
1261                                arguments,
1262                                arguments_handles,
1263                            };
1264
1265                            let write_result =
1266                                (op_constraints.write_fn)(&context_args, diagnostics);
1267                            let OperatorWriteOutput {
1268                                write_prologue,
1269                                write_iterator,
1270                                write_iterator_after,
1271                                write_tick_end,
1272                            } = write_result.unwrap_or_else(|()| {
1273                                assert!(
1274                                    diagnostics.has_error(),
1275                                    "Operator `{}` returned `Err` but emitted no diagnostics, this is a bug.",
1276                                    op_name,
1277                                );
1278                                OperatorWriteOutput {
1279                                    write_iterator: null_write_iterator_fn(&context_args),
1280                                    ..Default::default()
1281                                }
1282                            });
1283
1284                            op_prologue_code.push(syn::parse_quote! {
1285                                #[allow(non_snake_case)]
1286                                #[inline(always)]
1287                                fn #work_fn<T>(thunk: impl ::std::ops::FnOnce() -> T) -> T {
1288                                    thunk()
1289                                }
1290
1291                                #[allow(non_snake_case)]
1292                                #[inline(always)]
1293                                async fn #work_fn_async<T>(
1294                                    thunk: impl ::std::future::Future<Output = T>,
1295                                ) -> T {
1296                                    thunk.await
1297                                }
1298                            });
1299                            op_prologue_code.push(write_prologue);
1300                            op_tick_end_code.push(write_tick_end);
1301                            subgraph_op_iter_code.push(write_iterator);
1302
1303                            if include_type_guards {
1304                                let type_guard = if is_pull {
1305                                    quote_spanned! {op_span=>
1306                                        let #ident = {
1307                                            #[allow(non_snake_case)]
1308                                            #[inline(always)]
1309                                            pub fn #work_fn<Item, Input>(input: Input)
1310                                                -> impl #root::dfir_pipes::pull::Pull<Item = Item, Meta = (), CanPend = Input::CanPend, CanEnd = Input::CanEnd>
1311                                            where
1312                                                Input: #root::dfir_pipes::pull::Pull<Item = Item, Meta = ()>,
1313                                            {
1314                                                #root::pin_project_lite::pin_project! {
1315                                                    #[repr(transparent)]
1316                                                    struct Pull<Item, Input: #root::dfir_pipes::pull::Pull<Item = Item>> {
1317                                                        #[pin]
1318                                                        inner: Input
1319                                                    }
1320                                                }
1321
1322                                                impl<Item, Input> #root::dfir_pipes::pull::Pull for Pull<Item, Input>
1323                                                where
1324                                                    Input: #root::dfir_pipes::pull::Pull<Item = Item>,
1325                                                {
1326                                                    type Ctx<'ctx> = Input::Ctx<'ctx>;
1327
1328                                                    type Item = Item;
1329                                                    type Meta = Input::Meta;
1330                                                    type CanPend = Input::CanPend;
1331                                                    type CanEnd = Input::CanEnd;
1332
1333                                                    #[inline(always)]
1334                                                    fn pull(
1335                                                        self: ::std::pin::Pin<&mut Self>,
1336                                                        ctx: &mut Self::Ctx<'_>,
1337                                                    ) -> #root::dfir_pipes::pull::PullStep<Self::Item, Self::Meta, Self::CanPend, Self::CanEnd> {
1338                                                        #root::dfir_pipes::pull::Pull::pull(self.project().inner, ctx)
1339                                                    }
1340
1341                                                    #[inline(always)]
1342                                                    fn size_hint(&self) -> (usize, Option<usize>) {
1343                                                        #root::dfir_pipes::pull::Pull::size_hint(&self.inner)
1344                                                    }
1345                                                }
1346
1347                                                Pull {
1348                                                    inner: input
1349                                                }
1350                                            }
1351                                            #work_fn::<_, _>( #ident )
1352                                        };
1353                                    }
1354                                } else {
1355                                    quote_spanned! {op_span=>
1356                                        let #ident = {
1357                                            #[allow(non_snake_case)]
1358                                            #[inline(always)]
1359                                            pub fn #work_fn<Item, Psh>(psh: Psh) -> impl #root::dfir_pipes::push::Push<Item, (), CanPend = Psh::CanPend>
1360                                            where
1361                                                Psh: #root::dfir_pipes::push::Push<Item, ()>
1362                                            {
1363                                                #root::pin_project_lite::pin_project! {
1364                                                    #[repr(transparent)]
1365                                                    struct PushGuard<Psh> {
1366                                                        #[pin]
1367                                                        inner: Psh,
1368                                                    }
1369                                                }
1370
1371                                                impl<Item, Psh> #root::dfir_pipes::push::Push<Item, ()> for PushGuard<Psh>
1372                                                where
1373                                                    Psh: #root::dfir_pipes::push::Push<Item, ()>,
1374                                                {
1375                                                    type Ctx<'ctx> = Psh::Ctx<'ctx>;
1376
1377                                                    type CanPend = Psh::CanPend;
1378
1379                                                    #[inline(always)]
1380                                                    fn poll_ready(
1381                                                        self: ::std::pin::Pin<&mut Self>,
1382                                                        ctx: &mut Self::Ctx<'_>,
1383                                                    ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
1384                                                        #root::dfir_pipes::push::Push::poll_ready(self.project().inner, ctx)
1385                                                    }
1386
1387                                                    #[inline(always)]
1388                                                    fn start_send(
1389                                                        self: ::std::pin::Pin<&mut Self>,
1390                                                        item: Item,
1391                                                        meta: (),
1392                                                    ) {
1393                                                        #root::dfir_pipes::push::Push::start_send(self.project().inner, item, meta)
1394                                                    }
1395
1396                                                    #[inline(always)]
1397                                                    fn poll_flush(
1398                                                        self: ::std::pin::Pin<&mut Self>,
1399                                                        ctx: &mut Self::Ctx<'_>,
1400                                                    ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
1401                                                        #root::dfir_pipes::push::Push::poll_flush(self.project().inner, ctx)
1402                                                    }
1403
1404                                                    #[inline(always)]
1405                                                    fn size_hint(
1406                                                        self: ::std::pin::Pin<&mut Self>,
1407                                                        hint: (usize, Option<usize>),
1408                                                    ) {
1409                                                        #root::dfir_pipes::push::Push::size_hint(self.project().inner, hint)
1410                                                    }
1411                                                }
1412
1413                                                PushGuard {
1414                                                    inner: psh
1415                                                }
1416                                            }
1417                                            #work_fn( #ident )
1418                                        };
1419                                    }
1420                                };
1421                                subgraph_op_iter_code.push(type_guard);
1422                            }
1423                            subgraph_op_iter_after_code.push(write_iterator_after);
1424                        }
1425                    }
1426
1427                    {
1428                        // Determine pull and push halves of the `Pivot`.
1429                        let pull_ident = if 0 < pull_to_push_idx {
1430                            self.node_as_ident(subgraph_nodes[pull_to_push_idx - 1], false)
1431                        } else {
1432                            // Entire subgraph is push (with a single recv/pull handoff input).
1433                            recv_port_idents[0].clone()
1434                        };
1435
1436                        #[rustfmt::skip]
1437                        let push_ident = if let Some(&node_id) =
1438                            subgraph_nodes.get(pull_to_push_idx)
1439                        {
1440                            self.node_as_ident(node_id, false)
1441                        } else if 1 == send_port_idents.len() {
1442                            // Entire subgraph is pull (with a single send/push handoff output).
1443                            send_port_idents[0].clone()
1444                        } else {
1445                            diagnostics.push(Diagnostic::spanned(
1446                                pull_ident.span(),
1447                                Level::Error,
1448                                "Degenerate subgraph detected, is there a disconnected `null()` or other degenerate pipeline somewhere?",
1449                            ));
1450                            continue;
1451                        };
1452
1453                        // Pivot span is combination of pull and push spans (or if not possible, just take the push).
1454                        let pivot_span = pull_ident
1455                            .span()
1456                            .join(push_ident.span())
1457                            .unwrap_or_else(|| push_ident.span());
1458                        let pivot_fn_ident = Ident::new(
1459                            &format!("pivot_run_sg_{:?}", subgraph_id.data()),
1460                            pivot_span,
1461                        );
1462                        let root = change_spans(root.clone(), pivot_span);
1463                        subgraph_op_iter_code.push(quote_spanned! {pivot_span=>
1464                            #[inline(always)]
1465                            fn #pivot_fn_ident<Pul, Psh, Item>(pull: Pul, push: Psh)
1466                                -> impl ::std::future::Future<Output = ()>
1467                            where
1468                                Pul: #root::dfir_pipes::pull::Pull<Item = Item>,
1469                                Psh: #root::dfir_pipes::push::Push<Item, Pul::Meta>,
1470                            {
1471                                #root::dfir_pipes::pull::Pull::send_push(pull, push)
1472                            }
1473                            (#pivot_fn_ident)(#pull_ident, #push_ident).await;
1474                        });
1475                    }
1476                };
1477
1478                // Each subgraph block is an async block so it can be individually instrumented.
1479                // Note: this ident is for the subgraph future, not a runtime SubgraphId binding
1480                // (unlike the scheduled path's `sg_ident`).
1481                let sg_fut_ident = subgraph_id.as_ident(Span::call_site());
1482
1483                // Generate send-side curr_items_count updates (after subgraph runs).
1484                let send_metrics_code: Vec<TokenStream> = send_hoffs
1485                    .iter()
1486                    .zip(send_buf_idents.iter())
1487                    .map(|(&hoff_id, buf_ident)| {
1488                        let hoff_ffi = hoff_id.data().as_ffi();
1489                        quote! {
1490                            __dfir_metrics.handoffs[
1491                                #root::slotmap::KeyData::from_ffi(#hoff_ffi).into()
1492                            ].curr_items_count.set(#buf_ident.len());
1493                        }
1494                    })
1495                    .collect();
1496
1497                subgraph_blocks.push(quote! {
1498                    let #sg_fut_ident = async {
1499                        let #context = &#df;
1500                        #( #recv_port_code )*
1501                        #( #send_port_code )*
1502                        #( #subgraph_op_iter_code )*
1503                        #( #subgraph_op_iter_after_code )*
1504                    };
1505                    {
1506                        let sg_metrics = &__dfir_metrics.subgraphs[
1507                            #root::slotmap::KeyData::from_ffi(#sg_metrics_ffi).into()
1508                        ];
1509                        #root::scheduled::metrics::InstrumentSubgraph::new(
1510                            #sg_fut_ident, sg_metrics
1511                        ).await;
1512                        sg_metrics.total_run_count.update(|x| x + 1);
1513                    }
1514                    #( #send_metrics_code )*
1515                });
1516
1517                // Collect per-subgraph prologues into the main prologue lists.
1518                // (They are already pushed above in the operator loop.)
1519            }
1520        }
1521
1522        if diagnostics.has_error() {
1523            return Err(std::mem::take(diagnostics));
1524        }
1525        let _ = diagnostics; // Ensure no more diagnostics may be added after checking for errors.
1526
1527        let (meta_graph_arg, diagnostics_arg) = if include_meta {
1528            let meta_graph_json = serde_json::to_string(&self).unwrap();
1529            let meta_graph_json = Literal::string(&meta_graph_json);
1530
1531            let serde_diagnostics: Vec<_> = diagnostics.iter().map(Diagnostic::to_serde).collect();
1532            let diagnostics_json = serde_json::to_string(&*serde_diagnostics).unwrap();
1533            let diagnostics_json = Literal::string(&diagnostics_json);
1534
1535            (
1536                quote! { Some(#meta_graph_json) },
1537                quote! { Some(#diagnostics_json) },
1538            )
1539        } else {
1540            (quote! { None }, quote! { None })
1541        };
1542
1543        // Generate metrics initialization: one entry per handoff and per subgraph.
1544        let metrics_init_code = {
1545            let handoff_inits = handoff_nodes.iter().map(|&(node_id, _)| {
1546                let ffi = node_id.data().as_ffi();
1547                quote! {
1548                    dfir_metrics.handoffs.insert(
1549                        #root::slotmap::KeyData::from_ffi(#ffi).into(),
1550                        ::std::default::Default::default(),
1551                    );
1552                }
1553            });
1554            let subgraph_inits = all_subgraphs.iter().map(|&(sg_id, _)| {
1555                let ffi = sg_id.data().as_ffi();
1556                quote! {
1557                    dfir_metrics.subgraphs.insert(
1558                        #root::slotmap::KeyData::from_ffi(#ffi).into(),
1559                        ::std::default::Default::default(),
1560                    );
1561                }
1562            });
1563            handoff_inits.chain(subgraph_inits).collect::<Vec<_>>()
1564        };
1565
1566        // Prologues and buffer declarations persist across ticks (outside the closure).
1567        // Subgraph blocks run each tick (inside the closure).
1568        Ok(quote! {
1569            {
1570                #prefix
1571
1572                use #root::{var_expr, var_args};
1573
1574                let __dfir_wake_state = ::std::sync::Arc::new(
1575                    #root::scheduled::context::WakeState::default()
1576                );
1577
1578                let __dfir_metrics = {
1579                    let mut dfir_metrics = #root::scheduled::metrics::DfirMetrics::default();
1580                    #( #metrics_init_code )*
1581                    ::std::rc::Rc::new(dfir_metrics)
1582                };
1583
1584                #[allow(unused_mut)]
1585                let mut #df = #root::scheduled::context::Context::new(
1586                    ::std::clone::Clone::clone(&__dfir_wake_state),
1587                    __dfir_metrics,
1588                );
1589
1590                #( #buffer_code )*
1591                #( #back_buffer_code )*
1592                #( #op_prologue_code )*
1593
1594                // Pre-set to true so the first tick always returns true
1595                // (matching Dfir pre-scheduling behavior). Subsequent ticks
1596                // start false (from take()) and are set true by recv port code
1597                // if any handoff buffer has data.
1598                let mut __dfir_work_done = true;
1599                #[allow(unused_qualifications, unused_mut, unused_variables, clippy::await_holding_refcell_ref, clippy::deref_addrof)]
1600                let __dfir_inline_tick = async move |#df: &mut #root::scheduled::context::Context| {
1601                    let __dfir_metrics = #df.metrics();
1602                    // Double-buffer swap for defer_tick handoffs: move last tick's
1603                    // producer output into the back buffer for the consumer to drain.
1604                    #( #back_edge_swap_code )*
1605                    #( #subgraph_blocks )*
1606
1607                    // For non-lazy defer_tick: if any deferred buffer has data,
1608                    // signal that another tick should run.
1609                    if false #( || !#defer_tick_buf_idents.is_empty() )* {
1610                        #df.schedule_subgraph(true);
1611                    }
1612
1613                    // End-of-tick state reset (e.g. 'tick persistence).
1614                    #( #op_tick_end_code )*
1615
1616                    #df.__end_tick();
1617                    ::std::mem::take(&mut __dfir_work_done)
1618                };
1619                #root::scheduled::context::Dfir::new(
1620                    __dfir_inline_tick,
1621                    #df,
1622                    #meta_graph_arg,
1623                    #diagnostics_arg,
1624                )
1625            }
1626        })
1627    }
1628
1629    /// Color mode (pull vs. push, handoff vs. comp) for nodes. Some nodes can be push *OR* pull;
1630    /// those nodes will not be set in the returned map.
1631    pub fn node_color_map(&self) -> SparseSecondaryMap<GraphNodeId, Color> {
1632        let mut node_color_map: SparseSecondaryMap<GraphNodeId, Color> = self
1633            .node_ids()
1634            .filter_map(|node_id| {
1635                let op_color = self.node_color(node_id)?;
1636                Some((node_id, op_color))
1637            })
1638            .collect();
1639
1640        // Fill in rest via subgraphs.
1641        for sg_nodes in self.subgraph_nodes.values() {
1642            let pull_to_push_idx = self.find_pull_to_push_idx(sg_nodes);
1643
1644            for (idx, node_id) in sg_nodes.iter().copied().enumerate() {
1645                let is_pull = idx < pull_to_push_idx;
1646                node_color_map.insert(node_id, if is_pull { Color::Pull } else { Color::Push });
1647            }
1648        }
1649
1650        node_color_map
1651    }
1652
1653    /// Writes this graph as mermaid into a string.
1654    pub fn to_mermaid(&self, write_config: &WriteConfig) -> String {
1655        let mut output = String::new();
1656        self.write_mermaid(&mut output, write_config).unwrap();
1657        output
1658    }
1659
1660    /// Writes this graph as mermaid into the given `Write`.
1661    pub fn write_mermaid(
1662        &self,
1663        output: impl std::fmt::Write,
1664        write_config: &WriteConfig,
1665    ) -> std::fmt::Result {
1666        let mut graph_write = Mermaid::new(output);
1667        self.write_graph(&mut graph_write, write_config)
1668    }
1669
1670    /// Writes this graph as DOT (graphviz) into a string.
1671    pub fn to_dot(&self, write_config: &WriteConfig) -> String {
1672        let mut output = String::new();
1673        let mut graph_write = Dot::new(&mut output);
1674        self.write_graph(&mut graph_write, write_config).unwrap();
1675        output
1676    }
1677
1678    /// Writes this graph as DOT (graphviz) into the given `Write`.
1679    pub fn write_dot(
1680        &self,
1681        output: impl std::fmt::Write,
1682        write_config: &WriteConfig,
1683    ) -> std::fmt::Result {
1684        let mut graph_write = Dot::new(output);
1685        self.write_graph(&mut graph_write, write_config)
1686    }
1687
1688    /// Write out this graph using the given `GraphWrite`. E.g. `Mermaid` or `Dot.
1689    pub(crate) fn write_graph<W>(
1690        &self,
1691        mut graph_write: W,
1692        write_config: &WriteConfig,
1693    ) -> Result<(), W::Err>
1694    where
1695        W: GraphWrite,
1696    {
1697        fn helper_edge_label(
1698            src_port: &PortIndexValue,
1699            dst_port: &PortIndexValue,
1700        ) -> Option<String> {
1701            let src_label = match src_port {
1702                PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1703                PortIndexValue::Int(index) => Some(index.value.to_string()),
1704                _ => None,
1705            };
1706            let dst_label = match dst_port {
1707                PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1708                PortIndexValue::Int(index) => Some(index.value.to_string()),
1709                _ => None,
1710            };
1711            let label = match (src_label, dst_label) {
1712                (Some(l1), Some(l2)) => Some(format!("{}\n{}", l1, l2)),
1713                (Some(l1), None) => Some(l1),
1714                (None, Some(l2)) => Some(l2),
1715                (None, None) => None,
1716            };
1717            label
1718        }
1719
1720        // Make node color map one time.
1721        let node_color_map = self.node_color_map();
1722
1723        // Write prologue.
1724        graph_write.write_prologue()?;
1725
1726        // Define nodes.
1727        let mut skipped_handoffs = BTreeSet::new();
1728        let mut subgraph_handoffs = <BTreeMap<GraphSubgraphId, Vec<GraphNodeId>>>::new();
1729        for (node_id, node) in self.nodes() {
1730            if matches!(node, GraphNode::Handoff { .. }) {
1731                if write_config.no_handoffs {
1732                    skipped_handoffs.insert(node_id);
1733                    continue;
1734                } else {
1735                    let pred_node = self.node_predecessor_nodes(node_id).next().unwrap();
1736                    let pred_sg = self.node_subgraph(pred_node);
1737                    let succ_node = self.node_successor_nodes(node_id).next().unwrap();
1738                    let succ_sg = self.node_subgraph(succ_node);
1739                    if let Some((pred_sg, succ_sg)) = pred_sg.zip(succ_sg)
1740                        && pred_sg == succ_sg
1741                    {
1742                        subgraph_handoffs.entry(pred_sg).or_default().push(node_id);
1743                    }
1744                }
1745            }
1746            graph_write.write_node_definition(
1747                node_id,
1748                &if write_config.op_short_text {
1749                    node.to_name_string()
1750                } else if write_config.op_text_no_imports {
1751                    // Remove any lines that start with "use" (imports)
1752                    let full_text = node.to_pretty_string();
1753                    let mut output = String::new();
1754                    for sentence in full_text.split('\n') {
1755                        if sentence.trim().starts_with("use") {
1756                            continue;
1757                        }
1758                        output.push('\n');
1759                        output.push_str(sentence);
1760                    }
1761                    output.into()
1762                } else {
1763                    node.to_pretty_string()
1764                },
1765                if write_config.no_pull_push {
1766                    None
1767                } else {
1768                    node_color_map.get(node_id).copied()
1769                },
1770            )?;
1771        }
1772
1773        // Write edges.
1774        for (edge_id, (src_id, mut dst_id)) in self.edges() {
1775            // Handling for if `write_config.no_handoffs` true.
1776            if skipped_handoffs.contains(&src_id) {
1777                continue;
1778            }
1779
1780            let (src_port, mut dst_port) = self.edge_ports(edge_id);
1781            if skipped_handoffs.contains(&dst_id) {
1782                let mut handoff_succs = self.node_successors(dst_id);
1783                assert_eq!(1, handoff_succs.len());
1784                let (succ_edge, succ_node) = handoff_succs.next().unwrap();
1785                dst_id = succ_node;
1786                dst_port = self.edge_ports(succ_edge).1;
1787            }
1788
1789            let label = helper_edge_label(src_port, dst_port);
1790            let delay_type = self
1791                .node_op_inst(dst_id)
1792                .and_then(|op_inst| (op_inst.op_constraints.input_delaytype_fn)(dst_port));
1793            graph_write.write_edge(src_id, dst_id, delay_type, label.as_deref(), false)?;
1794        }
1795
1796        // Write reference edges.
1797        if !write_config.no_references {
1798            for dst_id in self.node_ids() {
1799                for src_ref_id in self
1800                    .node_singleton_references(dst_id)
1801                    .iter()
1802                    .copied()
1803                    .flatten()
1804                {
1805                    let delay_type = None;
1806                    let label = None;
1807                    graph_write.write_edge(src_ref_id, dst_id, delay_type, label, true)?;
1808                }
1809            }
1810        }
1811
1812        // The following code is a little bit tricky. Generally, the graph has the hierarchy:
1813        // `loop -> subgraph -> varname -> node`. However, each of these can be disabled via the `write_config`. To
1814        // handle both the enabled and disabled case, this code is structured as a series of nested loops. If the layer
1815        // is disabled, then the HashMap<Option<KEY>, Vec<VALUE>> will only have a single key (`None`) with a
1816        // corresponding `Vec` value containing everything. This way no special handling is needed for the next layer.
1817
1818        // Loop -> Subgraphs
1819        let loop_subgraphs = self.subgraph_ids().map(|sg_id| {
1820            let loop_id = if write_config.no_loops {
1821                None
1822            } else {
1823                self.subgraph_loop(sg_id)
1824            };
1825            (loop_id, sg_id)
1826        });
1827        let loop_subgraphs = into_group_map(loop_subgraphs);
1828        for (loop_id, subgraph_ids) in loop_subgraphs {
1829            if let Some(loop_id) = loop_id {
1830                graph_write.write_loop_start(loop_id)?;
1831            }
1832
1833            // Subgraph -> Varnames.
1834            let subgraph_varnames_nodes = subgraph_ids.into_iter().flat_map(|sg_id| {
1835                self.subgraph(sg_id).iter().copied().map(move |node_id| {
1836                    let opt_sg_id = if write_config.no_subgraphs {
1837                        None
1838                    } else {
1839                        Some(sg_id)
1840                    };
1841                    (opt_sg_id, (self.node_varname(node_id), node_id))
1842                })
1843            });
1844            let subgraph_varnames_nodes = into_group_map(subgraph_varnames_nodes);
1845            for (sg_id, varnames) in subgraph_varnames_nodes {
1846                if let Some(sg_id) = sg_id {
1847                    graph_write.write_subgraph_start(sg_id)?;
1848                }
1849
1850                // Varnames -> Nodes.
1851                let varname_nodes = varnames.into_iter().map(|(varname, node)| {
1852                    let varname = if write_config.no_varnames {
1853                        None
1854                    } else {
1855                        varname
1856                    };
1857                    (varname, node)
1858                });
1859                let varname_nodes = into_group_map(varname_nodes);
1860                for (varname, node_ids) in varname_nodes {
1861                    if let Some(varname) = varname {
1862                        graph_write.write_varname_start(&varname.0.to_string(), sg_id)?;
1863                    }
1864
1865                    // Write all nodes.
1866                    for node_id in node_ids {
1867                        graph_write.write_node(node_id)?;
1868                    }
1869
1870                    if varname.is_some() {
1871                        graph_write.write_varname_end()?;
1872                    }
1873                }
1874
1875                if sg_id.is_some() {
1876                    graph_write.write_subgraph_end()?;
1877                }
1878            }
1879
1880            if loop_id.is_some() {
1881                graph_write.write_loop_end()?;
1882            }
1883        }
1884
1885        // Write epilogue.
1886        graph_write.write_epilogue()?;
1887
1888        Ok(())
1889    }
1890
1891    /// Convert back into surface syntax.
1892    pub fn surface_syntax_string(&self) -> String {
1893        let mut string = String::new();
1894        self.write_surface_syntax(&mut string).unwrap();
1895        string
1896    }
1897
1898    /// Convert back into surface syntax.
1899    pub fn write_surface_syntax(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1900        for (key, node) in self.nodes.iter() {
1901            match node {
1902                GraphNode::Operator(op) => {
1903                    writeln!(write, "{:?} = {};", key.data(), op.to_token_stream())?;
1904                }
1905                GraphNode::Handoff { .. } => {
1906                    writeln!(write, "// {:?} = <handoff>;", key.data())?;
1907                }
1908                GraphNode::ModuleBoundary { .. } => panic!(),
1909            }
1910        }
1911        writeln!(write)?;
1912        for (_e, (src_key, dst_key)) in self.graph.edges() {
1913            writeln!(write, "{:?} -> {:?};", src_key.data(), dst_key.data())?;
1914        }
1915        Ok(())
1916    }
1917
1918    /// Convert into a [mermaid](https://mermaid-js.github.io/) graph. Ignores subgraphs.
1919    pub fn mermaid_string_flat(&self) -> String {
1920        let mut string = String::new();
1921        self.write_mermaid_flat(&mut string).unwrap();
1922        string
1923    }
1924
1925    /// Convert into a [mermaid](https://mermaid-js.github.io/) graph. Ignores subgraphs.
1926    pub fn write_mermaid_flat(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1927        writeln!(write, "flowchart TB")?;
1928        for (key, node) in self.nodes.iter() {
1929            match node {
1930                GraphNode::Operator(operator) => writeln!(
1931                    write,
1932                    "    %% {span}\n    {id:?}[\"{row_col} <tt>{code}</tt>\"]",
1933                    span = PrettySpan(node.span()),
1934                    id = key.data(),
1935                    row_col = PrettyRowCol(node.span()),
1936                    code = operator
1937                        .to_token_stream()
1938                        .to_string()
1939                        .replace('&', "&amp;")
1940                        .replace('<', "&lt;")
1941                        .replace('>', "&gt;")
1942                        .replace('"', "&quot;")
1943                        .replace('\n', "<br>"),
1944                ),
1945                GraphNode::Handoff { .. } => {
1946                    writeln!(write, r#"    {:?}{{"{}"}}"#, key.data(), HANDOFF_NODE_STR)
1947                }
1948                GraphNode::ModuleBoundary { .. } => {
1949                    writeln!(
1950                        write,
1951                        r#"    {:?}{{"{}"}}"#,
1952                        key.data(),
1953                        MODULE_BOUNDARY_NODE_STR
1954                    )
1955                }
1956            }?;
1957        }
1958        writeln!(write)?;
1959        for (_e, (src_key, dst_key)) in self.graph.edges() {
1960            writeln!(write, "    {:?}-->{:?}", src_key.data(), dst_key.data())?;
1961        }
1962        Ok(())
1963    }
1964}
1965
1966/// Loops
1967impl DfirGraph {
1968    /// Iterator over all loop IDs.
1969    pub fn loop_ids(&self) -> slotmap::basic::Keys<'_, GraphLoopId, Vec<GraphNodeId>> {
1970        self.loop_nodes.keys()
1971    }
1972
1973    /// Iterator over all loops, ID and members: `(GraphLoopId, Vec<GraphNodeId>)`.
1974    pub fn loops(&self) -> slotmap::basic::Iter<'_, GraphLoopId, Vec<GraphNodeId>> {
1975        self.loop_nodes.iter()
1976    }
1977
1978    /// Create a new loop context, with the given parent loop (or `None`).
1979    pub fn insert_loop(&mut self, parent_loop: Option<GraphLoopId>) -> GraphLoopId {
1980        let loop_id = self.loop_nodes.insert(Vec::new());
1981        self.loop_children.insert(loop_id, Vec::new());
1982        if let Some(parent_loop) = parent_loop {
1983            self.loop_parent.insert(loop_id, parent_loop);
1984            self.loop_children
1985                .get_mut(parent_loop)
1986                .unwrap()
1987                .push(loop_id);
1988        } else {
1989            self.root_loops.push(loop_id);
1990        }
1991        loop_id
1992    }
1993
1994    /// Get a node's loop context (or `None` for root).
1995    pub fn node_loop(&self, node_id: GraphNodeId) -> Option<GraphLoopId> {
1996        self.node_loops.get(node_id).copied()
1997    }
1998
1999    /// Get a subgraph's loop context (or `None` for root).
2000    pub fn subgraph_loop(&self, subgraph_id: GraphSubgraphId) -> Option<GraphLoopId> {
2001        let &node_id = self.subgraph(subgraph_id).first().unwrap();
2002        let out = self.node_loop(node_id);
2003        debug_assert!(
2004            self.subgraph(subgraph_id)
2005                .iter()
2006                .all(|&node_id| self.node_loop(node_id) == out),
2007            "Subgraph nodes should all have the same loop context."
2008        );
2009        out
2010    }
2011
2012    /// Get a loop context's parent loop context (or `None` for root).
2013    pub fn loop_parent(&self, loop_id: GraphLoopId) -> Option<GraphLoopId> {
2014        self.loop_parent.get(loop_id).copied()
2015    }
2016
2017    /// Get a loop context's child loops.
2018    pub fn loop_children(&self, loop_id: GraphLoopId) -> &Vec<GraphLoopId> {
2019        self.loop_children.get(loop_id).unwrap()
2020    }
2021}
2022
/// Configuration for writing graphs.
// With the `clap-derive` feature enabled this derives `clap::Args`. Each field becomes a `--long`
// boolean flag (e.g. `--no-subgraphs`), and each field's `///` doc comment doubles as that flag's
// help text, so treat that wording as user-facing. The derived `Default` sets every flag to
// `false`, i.e. nothing is suppressed and op text is not shortened.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "clap-derive", derive(clap::Args))]
pub struct WriteConfig {
    // Rendering toggles: each `no_*` flag suppresses one kind of element in the output.
    /// Subgraphs will not be rendered if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_subgraphs: bool,
    /// Variable names will not be rendered if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_varnames: bool,
    /// Will not render pull/push shapes if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_pull_push: bool,
    /// Will not render handoffs if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_handoffs: bool,
    /// Will not render singleton references if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_references: bool,
    /// Will not render loops if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_loops: bool,

    // Operator label formatting: controls how much of each operator's text is shown.
    /// Op text will only be their name instead of the whole source.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_short_text: bool,
    /// Op text will exclude any line that starts with "use".
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_text_no_imports: bool,
}
2053
/// Enum for choosing between mermaid and dot graph writing.
// With the `clap-derive` feature enabled this also derives `clap::ValueEnum`. The variants can
// then be parsed from command-line values (kebab-case by default: `mermaid`, `dot`), and their
// `///` doc comments serve as the per-value help text, so that wording is user-facing.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "clap-derive", derive(clap::Parser, clap::ValueEnum))]
pub enum WriteGraphType {
    /// Mermaid graphs.
    Mermaid,
    /// Dot (Graphviz) graphs.
    Dot,
}
2063
/// Groups `(key, value)` pairs by key, like [`itertools::Itertools::into_group_map`] but
/// collecting into a `BTreeMap`, so the resulting keys iterate in sorted order.
///
/// Within each group, values keep the order in which they appeared in `iter`.
fn into_group_map<K, V>(iter: impl IntoIterator<Item = (K, V)>) -> BTreeMap<K, Vec<V>>
where
    K: Ord,
{
    iter.into_iter().fold(BTreeMap::new(), |mut groups, (key, value)| {
        groups.entry(key).or_insert_with(Vec::new).push(value);
        groups
    })
}