1#![warn(missing_docs)]
2
3extern crate proc_macro;
4
5use std::collections::{BTreeMap, BTreeSet};
6use std::fmt::Debug;
7use std::iter::FusedIterator;
8
9use itertools::Itertools;
10use proc_macro2::{Ident, Literal, Span, TokenStream};
11use quote::{ToTokens, format_ident, quote, quote_spanned};
12use serde::{Deserialize, Serialize};
13use slotmap::{Key, SecondaryMap, SlotMap, SparseSecondaryMap};
14use syn::spanned::Spanned;
15
16use super::graph_write::{Dot, GraphWrite, Mermaid};
17use super::ops::{
18 DelayType, OPERATORS, OperatorWriteOutput, WriteContextArgs, find_op_op_constraints,
19 null_write_iterator_fn,
20};
21use super::{
22 CONTEXT, Color, DiMulGraph, GRAPH, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId,
23 GraphSubgraphId, HANDOFF_NODE_STR, MODULE_BOUNDARY_NODE_STR, OperatorInstance, PortIndexValue,
24 Varname, change_spans, get_operator_generics,
25};
26use crate::diagnostic::{Diagnostic, Diagnostics, Level};
27use crate::pretty_span::{PrettyRowCol, PrettySpan};
28use crate::process_singletons;
29
/// A DFIR dataflow graph.
///
/// Holds operator, handoff, and module-boundary nodes connected by port-annotated edges, plus
/// metadata derived during compilation: loop membership, subgraph partitioning, singleton
/// references/outputs, and handoff delay types. This metadata is read by code generation.
#[derive(Default, Debug, Serialize, Deserialize)]
pub struct DfirGraph {
    /// Every node in the graph (operators, handoffs, and module boundaries).
    nodes: SlotMap<GraphNodeId, GraphNode>,

    /// Per-operator instance data (resolved constraints, ports, generics, arguments).
    /// Not serialized.
    #[serde(skip)]
    operator_instances: SecondaryMap<GraphNodeId, OperatorInstance>,
    /// Optional per-operator tag. When present, codegen uses it instead of the source location
    /// when naming the operator's work function.
    operator_tag: SecondaryMap<GraphNodeId, String>,
    /// Graph topology: the edges between nodes.
    graph: DiMulGraph<GraphNodeId, GraphEdgeId>,
    /// `(source port, destination port)` for each edge.
    ports: SecondaryMap<GraphEdgeId, (PortIndexValue, PortIndexValue)>,

    /// The loop each node was inserted into, if any (inverse of `loop_nodes`).
    node_loops: SecondaryMap<GraphNodeId, GraphLoopId>,
    /// The nodes inserted into each loop.
    loop_nodes: SlotMap<GraphLoopId, Vec<GraphNodeId>>,
    /// Presumably the enclosing loop of each nested loop (not used in this chunk; confirm).
    loop_parent: SparseSecondaryMap<GraphLoopId, GraphLoopId>,
    /// Presumably the loops with no enclosing loop (not used in this chunk; confirm).
    root_loops: Vec<GraphLoopId>,
    /// Presumably the nested loops of each loop (not used in this chunk; confirm).
    loop_children: SecondaryMap<GraphLoopId, Vec<GraphLoopId>>,

    /// The subgraph each node was assigned to (inverse of `subgraph_nodes`).
    node_subgraph: SecondaryMap<GraphNodeId, GraphSubgraphId>,

    /// The nodes of each subgraph, in order. Codegen splits this list into a leading "pull"
    /// half and a trailing "push" half, so the order matters.
    subgraph_nodes: SlotMap<GraphSubgraphId, Vec<GraphNodeId>>,

    /// For each node, the singleton nodes referenced by its arguments. `None` marks a
    /// reference that has not been resolved; codegen treats an unresolved reference as a bug.
    node_singleton_references: SparseSecondaryMap<GraphNodeId, Vec<Option<GraphNodeId>>>,
    /// The variable name each node was bound to, if any.
    node_varnames: SparseSecondaryMap<GraphNodeId, Varname>,

    /// Delay type of handoffs whose data is deferred (e.g. to the next tick). Codegen gives
    /// these handoffs a second "back" buffer that is swapped with the main buffer.
    handoff_delay_type: SparseSecondaryMap<GraphNodeId, DelayType>,

    /// Set of nodes that produce singleton output, filled by `compute_node_singletons`.
    node_is_singleton: SparseSecondaryMap<GraphNodeId, ()>,
}
87
88impl DfirGraph {
90 pub fn new() -> Self {
92 Default::default()
93 }
94}
95
96impl DfirGraph {
98 pub fn node(&self, node_id: GraphNodeId) -> &GraphNode {
100 self.nodes.get(node_id).expect("Node not found.")
101 }
102
103 pub fn node_op_inst(&self, node_id: GraphNodeId) -> Option<&OperatorInstance> {
108 self.operator_instances.get(node_id)
109 }
110
111 pub fn node_varname(&self, node_id: GraphNodeId) -> Option<&Varname> {
113 self.node_varnames.get(node_id)
114 }
115
116 pub fn node_subgraph(&self, node_id: GraphNodeId) -> Option<GraphSubgraphId> {
118 self.node_subgraph.get(node_id).copied()
119 }
120
121 pub fn node_degree_in(&self, node_id: GraphNodeId) -> usize {
123 self.graph.degree_in(node_id)
124 }
125
126 pub fn node_degree_out(&self, node_id: GraphNodeId) -> usize {
128 self.graph.degree_out(node_id)
129 }
130
131 pub fn node_successors(
133 &self,
134 src: GraphNodeId,
135 ) -> impl '_
136 + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
137 + ExactSizeIterator
138 + FusedIterator
139 + Clone
140 + Debug {
141 self.graph.successors(src)
142 }
143
144 pub fn node_predecessors(
146 &self,
147 dst: GraphNodeId,
148 ) -> impl '_
149 + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
150 + ExactSizeIterator
151 + FusedIterator
152 + Clone
153 + Debug {
154 self.graph.predecessors(dst)
155 }
156
157 pub fn node_successor_edges(
159 &self,
160 src: GraphNodeId,
161 ) -> impl '_
162 + DoubleEndedIterator<Item = GraphEdgeId>
163 + ExactSizeIterator
164 + FusedIterator
165 + Clone
166 + Debug {
167 self.graph.successor_edges(src)
168 }
169
170 pub fn node_predecessor_edges(
172 &self,
173 dst: GraphNodeId,
174 ) -> impl '_
175 + DoubleEndedIterator<Item = GraphEdgeId>
176 + ExactSizeIterator
177 + FusedIterator
178 + Clone
179 + Debug {
180 self.graph.predecessor_edges(dst)
181 }
182
183 pub fn node_successor_nodes(
185 &self,
186 src: GraphNodeId,
187 ) -> impl '_
188 + DoubleEndedIterator<Item = GraphNodeId>
189 + ExactSizeIterator
190 + FusedIterator
191 + Clone
192 + Debug {
193 self.graph.successor_vertices(src)
194 }
195
196 pub fn node_predecessor_nodes(
198 &self,
199 dst: GraphNodeId,
200 ) -> impl '_
201 + DoubleEndedIterator<Item = GraphNodeId>
202 + ExactSizeIterator
203 + FusedIterator
204 + Clone
205 + Debug {
206 self.graph.predecessor_vertices(dst)
207 }
208
209 pub fn node_ids(&self) -> slotmap::basic::Keys<'_, GraphNodeId, GraphNode> {
211 self.nodes.keys()
212 }
213
214 pub fn nodes(&self) -> slotmap::basic::Iter<'_, GraphNodeId, GraphNode> {
216 self.nodes.iter()
217 }
218
219 pub fn insert_node(
221 &mut self,
222 node: GraphNode,
223 varname_opt: Option<Ident>,
224 loop_opt: Option<GraphLoopId>,
225 ) -> GraphNodeId {
226 let node_id = self.nodes.insert(node);
227 if let Some(varname) = varname_opt {
228 self.node_varnames.insert(node_id, Varname(varname));
229 }
230 if let Some(loop_id) = loop_opt {
231 self.node_loops.insert(node_id, loop_id);
232 self.loop_nodes[loop_id].push(node_id);
233 }
234 node_id
235 }
236
237 pub fn insert_node_op_inst(&mut self, node_id: GraphNodeId, op_inst: OperatorInstance) {
239 assert!(matches!(
240 self.nodes.get(node_id),
241 Some(GraphNode::Operator(_))
242 ));
243 let old_inst = self.operator_instances.insert(node_id, op_inst);
244 assert!(old_inst.is_none());
245 }
246
247 pub fn insert_node_op_insts_all(&mut self, diagnostics: &mut Diagnostics) {
249 let mut op_insts = Vec::new();
250 for (node_id, node) in self.nodes() {
251 let GraphNode::Operator(operator) = node else {
252 continue;
253 };
254 if self.node_op_inst(node_id).is_some() {
255 continue;
256 };
257
258 let Some(op_constraints) = find_op_op_constraints(operator) else {
260 diagnostics.push(Diagnostic::spanned(
261 operator.path.span(),
262 Level::Error,
263 format!("Unknown operator `{}`", operator.name_string()),
264 ));
265 continue;
266 };
267
268 let (input_ports, output_ports) = {
270 let mut input_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
271 .node_predecessors(node_id)
272 .map(|(edge_id, pred_id)| (self.edge_ports(edge_id).1, pred_id))
273 .collect();
274 input_edges.sort();
276 let input_ports: Vec<PortIndexValue> = input_edges
277 .into_iter()
278 .map(|(port, _pred)| port)
279 .cloned()
280 .collect();
281
282 let mut output_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
284 .node_successors(node_id)
285 .map(|(edge_id, succ)| (self.edge_ports(edge_id).0, succ))
286 .collect();
287 output_edges.sort();
289 let output_ports: Vec<PortIndexValue> = output_edges
290 .into_iter()
291 .map(|(port, _succ)| port)
292 .cloned()
293 .collect();
294
295 (input_ports, output_ports)
296 };
297
298 let generics = get_operator_generics(diagnostics, operator);
300 {
302 let generics_span = generics
304 .generic_args
305 .as_ref()
306 .map(Spanned::span)
307 .unwrap_or_else(|| operator.path.span());
308
309 if !op_constraints
310 .persistence_args
311 .contains(&generics.persistence_args.len())
312 {
313 diagnostics.push(Diagnostic::spanned(
314 generics.persistence_args_span().unwrap_or(generics_span),
315 Level::Error,
316 format!(
317 "`{}` should have {} persistence lifetime arguments, actually has {}.",
318 op_constraints.name,
319 op_constraints.persistence_args.human_string(),
320 generics.persistence_args.len()
321 ),
322 ));
323 }
324 if !op_constraints.type_args.contains(&generics.type_args.len()) {
325 diagnostics.push(Diagnostic::spanned(
326 generics.type_args_span().unwrap_or(generics_span),
327 Level::Error,
328 format!(
329 "`{}` should have {} generic type arguments, actually has {}.",
330 op_constraints.name,
331 op_constraints.type_args.human_string(),
332 generics.type_args.len()
333 ),
334 ));
335 }
336 }
337
338 op_insts.push((
339 node_id,
340 OperatorInstance {
341 op_constraints,
342 input_ports,
343 output_ports,
344 singletons_referenced: operator.singletons_referenced.clone(),
345 generics,
346 arguments_pre: operator.args.clone(),
347 arguments_raw: operator.args_raw.clone(),
348 },
349 ));
350 }
351
352 for (node_id, op_inst) in op_insts {
353 self.insert_node_op_inst(node_id, op_inst);
354 }
355 }
356
357 pub fn insert_intermediate_node(
369 &mut self,
370 edge_id: GraphEdgeId,
371 new_node: GraphNode,
372 ) -> (GraphNodeId, GraphEdgeId) {
373 let span = Some(new_node.span());
374
375 let op_inst_opt = 'oc: {
377 let GraphNode::Operator(operator) = &new_node else {
378 break 'oc None;
379 };
380 let Some(op_constraints) = find_op_op_constraints(operator) else {
381 break 'oc None;
382 };
383 let (input_port, output_port) = self.ports.get(edge_id).cloned().unwrap();
384
385 let mut dummy_diagnostics = Diagnostics::new();
386 let generics = get_operator_generics(&mut dummy_diagnostics, operator);
387 assert!(dummy_diagnostics.is_empty());
388
389 Some(OperatorInstance {
390 op_constraints,
391 input_ports: vec![input_port],
392 output_ports: vec![output_port],
393 singletons_referenced: operator.singletons_referenced.clone(),
394 generics,
395 arguments_pre: operator.args.clone(),
396 arguments_raw: operator.args_raw.clone(),
397 })
398 };
399
400 let node_id = self.nodes.insert(new_node);
402 if let Some(op_inst) = op_inst_opt {
404 self.operator_instances.insert(node_id, op_inst);
405 }
406 let (e0, e1) = self
408 .graph
409 .insert_intermediate_vertex(node_id, edge_id)
410 .unwrap();
411
412 let (src_idx, dst_idx) = self.ports.remove(edge_id).unwrap();
414 self.ports
415 .insert(e0, (src_idx, PortIndexValue::Elided(span)));
416 self.ports
417 .insert(e1, (PortIndexValue::Elided(span), dst_idx));
418
419 (node_id, e1)
420 }
421
422 pub fn remove_intermediate_node(&mut self, node_id: GraphNodeId) {
425 assert_eq!(
426 1,
427 self.node_degree_in(node_id),
428 "Removed intermediate node must have one predecessor"
429 );
430 assert_eq!(
431 1,
432 self.node_degree_out(node_id),
433 "Removed intermediate node must have one successor"
434 );
435 assert!(
436 self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
437 "Should not remove intermediate node after subgraph partitioning"
438 );
439
440 assert!(self.nodes.remove(node_id).is_some());
441 let (new_edge_id, (pred_edge_id, succ_edge_id)) =
442 self.graph.remove_intermediate_vertex(node_id).unwrap();
443 self.operator_instances.remove(node_id);
444 self.node_varnames.remove(node_id);
445
446 let (src_port, _) = self.ports.remove(pred_edge_id).unwrap();
447 let (_, dst_port) = self.ports.remove(succ_edge_id).unwrap();
448 self.ports.insert(new_edge_id, (src_port, dst_port));
449 }
450
451 pub(crate) fn node_color(&self, node_id: GraphNodeId) -> Option<Color> {
457 if matches!(self.node(node_id), GraphNode::Handoff { .. }) {
458 return Some(Color::Hoff);
459 }
460
461 if let GraphNode::Operator(op) = self.node(node_id)
463 && (op.name_string() == "resolve_futures_blocking"
464 || op.name_string() == "resolve_futures_blocking_ordered")
465 {
466 return Some(Color::Push);
467 }
468
469 let inn_degree = self.node_predecessor_nodes(node_id).len();
471 let out_degree = self.node_successor_nodes(node_id).len();
473
474 match (inn_degree, out_degree) {
475 (0, 0) => None, (0, 1) => Some(Color::Pull),
477 (1, 0) => Some(Color::Push),
478 (1, 1) => None, (_many, 0 | 1) => Some(Color::Pull),
480 (0 | 1, _many) => Some(Color::Push),
481 (_many, _to_many) => Some(Color::Comp),
482 }
483 }
484
485 pub fn set_operator_tag(&mut self, node_id: GraphNodeId, tag: String) {
487 self.operator_tag.insert(node_id, tag);
488 }
489}
490
491impl DfirGraph {
493 pub fn set_node_singleton_references(
496 &mut self,
497 node_id: GraphNodeId,
498 singletons_referenced: Vec<Option<GraphNodeId>>,
499 ) -> Option<Vec<Option<GraphNodeId>>> {
500 self.node_singleton_references
501 .insert(node_id, singletons_referenced)
502 }
503
504 pub fn node_singleton_references(&self, node_id: GraphNodeId) -> &[Option<GraphNodeId>] {
507 self.node_singleton_references
508 .get(node_id)
509 .map(std::ops::Deref::deref)
510 .unwrap_or_default()
511 }
512
513 pub fn node_is_singleton(&self, node_id: GraphNodeId) -> bool {
515 self.node_is_singleton.contains_key(node_id)
516 }
517
518 pub fn compute_node_singletons(&mut self) {
523 let node_ids: Vec<_> = self.node_ids().collect();
525 for node_id in node_ids {
526 let Some(op_inst) = self.operator_instances.get(node_id) else {
527 continue;
528 };
529 let op_constraints = op_inst.op_constraints;
530
531 if op_constraints.has_singleton_output {
532 self.node_is_singleton.insert(node_id, ());
533 } else if op_constraints.preserves_singleton {
534 let all_preds_singleton = self
536 .node_predecessor_nodes(node_id)
537 .all(|pred_id| self.node_is_singleton.contains_key(pred_id));
538 let has_preds = self.node_predecessor_nodes(node_id).next().is_some();
540 if has_preds && all_preds_singleton {
541 self.node_is_singleton.insert(node_id, ());
542 }
543 }
544 }
545 }
546}
547
548impl DfirGraph {
550 pub fn merge_modules(&mut self) -> Result<(), Diagnostic> {
558 let mod_bound_nodes = self
559 .nodes()
560 .filter(|(_nid, node)| matches!(node, GraphNode::ModuleBoundary { .. }))
561 .map(|(nid, _node)| nid)
562 .collect::<Vec<_>>();
563
564 for mod_bound_node in mod_bound_nodes {
565 self.remove_module_boundary(mod_bound_node)?;
566 }
567
568 Ok(())
569 }
570
571 fn remove_module_boundary(&mut self, mod_bound_node: GraphNodeId) -> Result<(), Diagnostic> {
575 assert!(
576 self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
577 "Should not remove intermediate node after subgraph partitioning"
578 );
579
580 let mut mod_pred_ports = BTreeMap::new();
581 let mut mod_succ_ports = BTreeMap::new();
582
583 for mod_out_edge in self.node_predecessor_edges(mod_bound_node) {
584 let (pred_port, succ_port) = self.edge_ports(mod_out_edge);
585 mod_pred_ports.insert(succ_port.clone(), (mod_out_edge, pred_port.clone()));
586 }
587
588 for mod_inn_edge in self.node_successor_edges(mod_bound_node) {
589 let (pred_port, succ_port) = self.edge_ports(mod_inn_edge);
590 mod_succ_ports.insert(pred_port.clone(), (mod_inn_edge, succ_port.clone()));
591 }
592
593 if mod_pred_ports.keys().collect::<BTreeSet<_>>()
594 != mod_succ_ports.keys().collect::<BTreeSet<_>>()
595 {
596 let GraphNode::ModuleBoundary { input, import_expr } = self.node(mod_bound_node) else {
598 panic!();
599 };
600
601 if *input {
602 return Err(Diagnostic {
603 span: *import_expr,
604 level: Level::Error,
605 message: format!(
606 "The ports into the module did not match. input: {:?}, expected: {:?}",
607 mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
608 mod_succ_ports.keys().map(|x| x.to_string()).join(", ")
609 ),
610 });
611 } else {
612 return Err(Diagnostic {
613 span: *import_expr,
614 level: Level::Error,
615 message: format!(
616 "The ports out of the module did not match. output: {:?}, expected: {:?}",
617 mod_succ_ports.keys().map(|x| x.to_string()).join(", "),
618 mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
619 ),
620 });
621 }
622 }
623
624 for (port, (pred_edge, pred_port)) in mod_pred_ports {
625 let (succ_edge, succ_port) = mod_succ_ports.remove(&port).unwrap();
626
627 let (src, _) = self.edge(pred_edge);
628 let (_, dst) = self.edge(succ_edge);
629 self.remove_edge(pred_edge);
630 self.remove_edge(succ_edge);
631
632 let new_edge_id = self.graph.insert_edge(src, dst);
633 self.ports.insert(new_edge_id, (pred_port, succ_port));
634 }
635
636 self.graph.remove_vertex(mod_bound_node);
637 self.nodes.remove(mod_bound_node);
638
639 Ok(())
640 }
641}
642
643impl DfirGraph {
645 pub fn edge(&self, edge_id: GraphEdgeId) -> (GraphNodeId, GraphNodeId) {
647 let (src, dst) = self.graph.edge(edge_id).expect("Edge not found.");
648 (src, dst)
649 }
650
651 pub fn edge_ports(&self, edge_id: GraphEdgeId) -> (&PortIndexValue, &PortIndexValue) {
653 let (src_port, dst_port) = self.ports.get(edge_id).expect("Edge not found.");
654 (src_port, dst_port)
655 }
656
657 pub fn edge_ids(&self) -> slotmap::basic::Keys<'_, GraphEdgeId, (GraphNodeId, GraphNodeId)> {
659 self.graph.edge_ids()
660 }
661
662 pub fn edges(
664 &self,
665 ) -> impl '_
666 + ExactSizeIterator<Item = (GraphEdgeId, (GraphNodeId, GraphNodeId))>
667 + FusedIterator
668 + Clone
669 + Debug {
670 self.graph.edges()
671 }
672
673 pub fn insert_edge(
675 &mut self,
676 src: GraphNodeId,
677 src_port: PortIndexValue,
678 dst: GraphNodeId,
679 dst_port: PortIndexValue,
680 ) -> GraphEdgeId {
681 let edge_id = self.graph.insert_edge(src, dst);
682 self.ports.insert(edge_id, (src_port, dst_port));
683 edge_id
684 }
685
686 pub fn remove_edge(&mut self, edge: GraphEdgeId) {
688 let (_src, _dst) = self.graph.remove_edge(edge).unwrap();
689 let (_src_port, _dst_port) = self.ports.remove(edge).unwrap();
690 }
691}
692
693impl DfirGraph {
695 pub fn subgraph(&self, subgraph_id: GraphSubgraphId) -> &Vec<GraphNodeId> {
697 self.subgraph_nodes
698 .get(subgraph_id)
699 .expect("Subgraph not found.")
700 }
701
702 pub fn subgraph_ids(&self) -> slotmap::basic::Keys<'_, GraphSubgraphId, Vec<GraphNodeId>> {
704 self.subgraph_nodes.keys()
705 }
706
707 pub fn subgraphs(&self) -> slotmap::basic::Iter<'_, GraphSubgraphId, Vec<GraphNodeId>> {
709 self.subgraph_nodes.iter()
710 }
711
712 pub fn insert_subgraph(
714 &mut self,
715 node_ids: Vec<GraphNodeId>,
716 ) -> Result<GraphSubgraphId, (GraphNodeId, GraphSubgraphId)> {
717 for &node_id in node_ids.iter() {
719 if let Some(&old_sg_id) = self.node_subgraph.get(node_id) {
720 return Err((node_id, old_sg_id));
721 }
722 }
723 let subgraph_id = self.subgraph_nodes.insert_with_key(|sg_id| {
724 for &node_id in node_ids.iter() {
725 self.node_subgraph.insert(node_id, sg_id);
726 }
727 node_ids
728 });
729
730 Ok(subgraph_id)
731 }
732
733 pub fn remove_from_subgraph(&mut self, node_id: GraphNodeId) -> bool {
735 if let Some(old_sg_id) = self.node_subgraph.remove(node_id) {
736 self.subgraph_nodes[old_sg_id].retain(|&other_node_id| other_node_id != node_id);
737 true
738 } else {
739 false
740 }
741 }
742
743 pub fn handoff_delay_type(&self, node_id: GraphNodeId) -> Option<DelayType> {
745 self.handoff_delay_type.get(node_id).copied()
746 }
747
748 pub fn set_handoff_delay_type(&mut self, node_id: GraphNodeId, delay_type: DelayType) {
750 self.handoff_delay_type.insert(node_id, delay_type);
751 }
752
753 fn find_pull_to_push_idx(&self, subgraph_nodes: &[GraphNodeId]) -> usize {
755 subgraph_nodes
756 .iter()
757 .position(|&node_id| {
758 self.node_color(node_id)
759 .is_some_and(|color| Color::Pull != color)
760 })
761 .unwrap_or(subgraph_nodes.len())
762 }
763}
764
765impl DfirGraph {
767 fn node_as_ident(&self, node_id: GraphNodeId, is_pred: bool) -> Ident {
769 let name = match &self.nodes[node_id] {
770 GraphNode::Operator(_) => format!("op_{:?}", node_id.data()),
771 GraphNode::Handoff { .. } => format!(
772 "hoff_{:?}_{}",
773 node_id.data(),
774 if is_pred { "recv" } else { "send" }
775 ),
776 GraphNode::ModuleBoundary { .. } => panic!(),
777 };
778 let span = match (is_pred, &self.nodes[node_id]) {
779 (_, GraphNode::Operator(operator)) => operator.span(),
780 (true, &GraphNode::Handoff { src_span, .. }) => src_span,
781 (false, &GraphNode::Handoff { dst_span, .. }) => dst_span,
782 (_, GraphNode::ModuleBoundary { .. }) => panic!(),
783 };
784 Ident::new(&name, span)
785 }
786
787 fn hoff_buf_ident(&self, hoff_id: GraphNodeId, span: Span) -> Ident {
789 Ident::new(&format!("hoff_{:?}_buf", hoff_id.data()), span)
790 }
791
792 fn hoff_back_ident(&self, hoff_id: GraphNodeId, span: Span) -> Ident {
794 Ident::new(&format!("hoff_{:?}_back", hoff_id.data()), span)
795 }
796
797 fn node_as_singleton_ident(&self, node_id: GraphNodeId, span: Span) -> Ident {
799 Ident::new(&format!("singleton_op_{:?}", node_id.data()), span)
800 }
801
802 fn helper_resolve_singletons(&self, node_id: GraphNodeId, span: Span) -> Vec<Ident> {
804 self.node_singleton_references(node_id)
805 .iter()
806 .map(|singleton_node_id| {
807 self.node_as_singleton_ident(
809 singleton_node_id
810 .expect("Expected singleton to be resolved but was not, this is a bug."),
811 span,
812 )
813 })
814 .collect::<Vec<_>>()
815 }
816
817 fn helper_collect_subgraph_handoffs(
820 &self,
821 ) -> SecondaryMap<GraphSubgraphId, (Vec<GraphNodeId>, Vec<GraphNodeId>)> {
822 let mut subgraph_handoffs: SecondaryMap<
824 GraphSubgraphId,
825 (Vec<GraphNodeId>, Vec<GraphNodeId>),
826 > = self
827 .subgraph_nodes
828 .keys()
829 .map(|k| (k, Default::default()))
830 .collect();
831
832 for (hoff_id, node) in self.nodes() {
834 if !matches!(node, GraphNode::Handoff { .. }) {
835 continue;
836 }
837 for (_edge, succ_id) in self.node_successors(hoff_id) {
839 let succ_sg = self.node_subgraph(succ_id).unwrap();
840 subgraph_handoffs[succ_sg].0.push(hoff_id);
841 }
842 for (_edge, pred_id) in self.node_predecessors(hoff_id) {
844 let pred_sg = self.node_subgraph(pred_id).unwrap();
845 subgraph_handoffs[pred_sg].1.push(hoff_id);
846 }
847 }
848
849 subgraph_handoffs
850 }
851
852 pub fn as_code(
867 &self,
868 root: &TokenStream,
869 include_type_guards: bool,
870 prefix: TokenStream,
871 diagnostics: &mut Diagnostics,
872 ) -> Result<TokenStream, Diagnostics> {
873 self.as_code_with_options(root, include_type_guards, true, prefix, diagnostics)
874 }
875
876 pub fn as_code_with_options(
885 &self,
886 root: &TokenStream,
887 include_type_guards: bool,
888 include_meta: bool,
889 prefix: TokenStream,
890 diagnostics: &mut Diagnostics,
891 ) -> Result<TokenStream, Diagnostics> {
892 let df = Ident::new(GRAPH, Span::call_site());
893 let context = Ident::new(CONTEXT, Span::call_site());
894
895 let handoff_nodes: Vec<_> = self
897 .nodes
898 .iter()
899 .filter_map(|(node_id, node)| match node {
900 GraphNode::Operator(_) => None,
901 &GraphNode::Handoff { src_span, dst_span } => Some((node_id, (src_span, dst_span))),
902 GraphNode::ModuleBoundary { .. } => panic!(),
903 })
904 .collect();
905
906 let buffer_code: Vec<TokenStream> = handoff_nodes
907 .iter()
908 .map(|&(node_id, (src_span, dst_span))| {
909 let span = src_span.join(dst_span).unwrap_or(src_span);
910 let buf_ident = self.hoff_buf_ident(node_id, span);
911 quote_spanned! {span=>
912 let mut #buf_ident: Vec<_> = Vec::new();
913 }
914 })
915 .collect();
916
917 let back_buffer_code: Vec<TokenStream> = handoff_nodes
922 .iter()
923 .filter(|(node_id, _)| self.handoff_delay_type(*node_id).is_some())
924 .map(|&(node_id, (src_span, dst_span))| {
925 let span = src_span.join(dst_span).unwrap_or(src_span);
926 let back_ident = self.hoff_back_ident(node_id, span);
927 quote_spanned! {span=>
928 let mut #back_ident: Vec<_> = Vec::new();
929 }
930 })
931 .collect();
932
933 let subgraph_handoffs = self.helper_collect_subgraph_handoffs();
935
936 let mut defer_tick_buf_idents: Vec<Ident> = Vec::new();
947 let mut back_edge_hoff_ids: BTreeSet<GraphNodeId> = BTreeSet::new();
948 let all_subgraphs = {
949 let mut sg_preds = SecondaryMap::<_, Vec<_>>::with_capacity(self.subgraph_nodes.len());
951 for (hoff_id, node) in self.nodes() {
952 if !matches!(node, GraphNode::Handoff { .. }) {
953 continue;
955 }
956 assert_eq!(1, self.node_successors(hoff_id).len());
957 assert_eq!(1, self.node_predecessors(hoff_id).len());
958 let (_edge_id, pred) = self.node_predecessors(hoff_id).next().unwrap();
959 let (_edge_id, succ) = self.node_successors(hoff_id).next().unwrap();
960 let pred_sg = self.node_subgraph(pred).unwrap();
961 let succ_sg = self.node_subgraph(succ).unwrap();
962 if pred_sg == succ_sg {
963 panic!("bug: unexpected subgraph self-handoff cycle");
964 }
965 if let Some(delay_type) = self.handoff_delay_type(hoff_id) {
966 debug_assert!(matches!(delay_type, DelayType::Tick | DelayType::TickLazy));
967 back_edge_hoff_ids.insert(hoff_id);
970
971 if !matches!(delay_type, DelayType::TickLazy) {
973 defer_tick_buf_idents.push(self.hoff_buf_ident(hoff_id, node.span()));
974 }
975 } else {
976 sg_preds.entry(succ_sg).unwrap().or_default().push(pred_sg);
977 }
978 }
979
980 for dst_id in self.node_ids() {
983 for src_ref_id in self
984 .node_singleton_references(dst_id)
985 .iter()
986 .copied()
987 .flatten()
988 {
989 let src_sg = self
990 .node_subgraph(src_ref_id)
991 .expect("bug: singleton ref node must belong to a subgraph");
992 let dst_sg = self
993 .node_subgraph(dst_id)
994 .expect("bug: singleton ref consumer must belong to a subgraph");
995 if src_sg != dst_sg {
996 sg_preds.entry(dst_sg).unwrap().or_default().push(src_sg);
997 }
998 }
999 }
1000
1001 let topo_sort = super::graph_algorithms::topo_sort(self.subgraph_ids(), |sg_id| {
1002 sg_preds.get(sg_id).into_iter().flatten().copied()
1003 })
1004 .expect("bug: unexpected cycle between subgraphs within the tick");
1005
1006 topo_sort
1007 .into_iter()
1008 .map(|sg_id| (sg_id, self.subgraph(sg_id)))
1009 .collect::<Vec<_>>()
1010 };
1011
1012 let back_edge_swap_code: Vec<TokenStream> = back_edge_hoff_ids
1016 .iter()
1017 .map(|&hoff_id| {
1018 let span = self.nodes[hoff_id].span();
1019 let buf_ident = self.hoff_buf_ident(hoff_id, span);
1020 let back_ident = self.hoff_back_ident(hoff_id, span);
1021 quote_spanned! {span=>
1022 ::std::mem::swap(&mut #buf_ident, &mut #back_ident);
1023 }
1024 })
1025 .collect();
1026
1027 let mut op_prologue_code = Vec::new();
1028 let mut op_tick_end_code = Vec::new();
1029 let mut subgraph_blocks = Vec::new();
1030 {
1031 for &(subgraph_id, subgraph_nodes) in all_subgraphs.iter() {
1032 let sg_metrics_ffi = subgraph_id.data().as_ffi();
1033 let (recv_hoffs, send_hoffs) = &subgraph_handoffs[subgraph_id];
1034
1035 let recv_port_idents: Vec<Ident> = recv_hoffs
1037 .iter()
1038 .map(|&hoff_id| self.node_as_ident(hoff_id, true))
1039 .collect();
1040 let send_port_idents: Vec<Ident> = send_hoffs
1041 .iter()
1042 .map(|&hoff_id| self.node_as_ident(hoff_id, false))
1043 .collect();
1044
1045 let recv_buf_idents: Vec<Ident> = recv_hoffs
1047 .iter()
1048 .map(|&hoff_id| self.hoff_buf_ident(hoff_id, self.nodes[hoff_id].span()))
1049 .collect();
1050 let send_buf_idents: Vec<Ident> = send_hoffs
1051 .iter()
1052 .map(|&hoff_id| self.hoff_buf_ident(hoff_id, self.nodes[hoff_id].span()))
1053 .collect();
1054
1055 let recv_port_code: Vec<TokenStream> = recv_port_idents
1059 .iter()
1060 .zip(recv_buf_idents.iter())
1061 .zip(recv_hoffs.iter())
1062 .map(|((port_ident, buf_ident), &hoff_id)| {
1063 let hoff_ffi = hoff_id.data().as_ffi();
1064 let work_done = Ident::new("__dfir_work_done", Span::call_site());
1068 let metrics = Ident::new("__dfir_metrics", Span::call_site());
1069 let drain_ident = if back_edge_hoff_ids.contains(&hoff_id) {
1072 self.hoff_back_ident(hoff_id, buf_ident.span())
1073 } else {
1074 buf_ident.clone()
1075 };
1076 quote_spanned! {port_ident.span()=>
1077 {
1078 let hoff_len = #drain_ident.len();
1079 if hoff_len > 0 {
1080 #work_done = true;
1081 }
1082 let hoff_metrics = &#metrics.handoffs[
1083 #root::slotmap::KeyData::from_ffi(#hoff_ffi).into()
1084 ];
1085 hoff_metrics.total_items_count.update(|x| x + hoff_len);
1086 hoff_metrics.curr_items_count.set(hoff_len);
1087 }
1088 let #port_ident = #root::dfir_pipes::pull::iter(#drain_ident.drain(..));
1089 }
1090 })
1091 .collect();
1092
1093 let send_port_code: Vec<TokenStream> = send_port_idents
1095 .iter()
1096 .zip(send_buf_idents.iter())
1097 .map(|(port_ident, buf_ident)| {
1098 quote_spanned! {port_ident.span()=>
1099 let #port_ident = #root::dfir_pipes::push::vec_push(&mut #buf_ident);
1100 }
1101 })
1102 .collect();
1103
1104 let loop_id = self.node_loop(subgraph_nodes[0]);
1106
1107 let mut subgraph_op_iter_code = Vec::new();
1108 let mut subgraph_op_iter_after_code = Vec::new();
1109 {
1110 let pull_to_push_idx = self.find_pull_to_push_idx(subgraph_nodes);
1111
1112 let (pull_half, push_half) = subgraph_nodes.split_at(pull_to_push_idx);
1113 let nodes_iter = pull_half.iter().chain(push_half.iter().rev());
1114
1115 for (idx, &node_id) in nodes_iter.enumerate() {
1116 let node = &self.nodes[node_id];
1117 assert!(
1118 matches!(node, GraphNode::Operator(_)),
1119 "Handoffs are not part of subgraphs."
1120 );
1121 let op_inst = &self.operator_instances[node_id];
1122
1123 let op_span = node.span();
1124 let op_name = op_inst.op_constraints.name;
1125 let root = change_spans(root.clone(), op_span);
1127 let op_constraints = OPERATORS
1128 .iter()
1129 .find(|op| op_name == op.name)
1130 .unwrap_or_else(|| panic!("Failed to find op: {}", op_name));
1131
1132 let ident = self.node_as_ident(node_id, false);
1133
1134 {
1135 let mut input_edges = self
1138 .graph
1139 .predecessor_edges(node_id)
1140 .map(|edge_id| (self.edge_ports(edge_id).1, edge_id))
1141 .collect::<Vec<_>>();
1142 input_edges.sort();
1144
1145 let inputs = input_edges
1146 .iter()
1147 .map(|&(_port, edge_id)| {
1148 let (pred, _) = self.edge(edge_id);
1149 self.node_as_ident(pred, true)
1150 })
1151 .collect::<Vec<_>>();
1152
1153 let mut output_edges = self
1155 .graph
1156 .successor_edges(node_id)
1157 .map(|edge_id| (&self.ports[edge_id].0, edge_id))
1158 .collect::<Vec<_>>();
1159 output_edges.sort();
1161
1162 let outputs = output_edges
1163 .iter()
1164 .map(|&(_port, edge_id)| {
1165 let (_, succ) = self.edge(edge_id);
1166 self.node_as_ident(succ, false)
1167 })
1168 .collect::<Vec<_>>();
1169
1170 let is_pull = idx < pull_to_push_idx;
1171
1172 let singleton_output_ident = &if op_constraints.has_singleton_output {
1173 self.node_as_singleton_ident(node_id, op_span)
1174 } else {
1175 Ident::new(&format!("{}_has_no_singleton_output", op_name), op_span)
1177 };
1178
1179 let df_local = &Ident::new(GRAPH, op_span.resolved_at(df.span()));
1188 let context = &Ident::new(CONTEXT, op_span.resolved_at(context.span()));
1189
1190 let singletons_resolved =
1191 self.helper_resolve_singletons(node_id, op_span);
1192 let arguments = &process_singletons::postprocess_singletons(
1193 op_inst.arguments_raw.clone(),
1194 singletons_resolved.clone(),
1195 );
1196 let arguments_handles =
1197 &process_singletons::postprocess_singletons_handles(
1198 op_inst.arguments_raw.clone(),
1199 singletons_resolved.clone(),
1200 );
1201
1202 let source_tag = 'a: {
1203 if let Some(tag) = self.operator_tag.get(node_id).cloned() {
1204 break 'a tag;
1205 }
1206
1207 #[cfg(nightly)]
1208 if proc_macro::is_available() {
1209 let op_span = op_span.unwrap();
1210 break 'a format!(
1211 "loc_{}_{}_{}_{}_{}",
1212 crate::pretty_span::make_source_path_relative(
1213 &op_span.file()
1214 )
1215 .display()
1216 .to_string()
1217 .replace(|x: char| !x.is_ascii_alphanumeric(), "_"),
1218 op_span.start().line(),
1219 op_span.start().column(),
1220 op_span.end().line(),
1221 op_span.end().column(),
1222 );
1223 }
1224
1225 format!(
1226 "loc_nopath_{}_{}_{}_{}",
1227 op_span.start().line,
1228 op_span.start().column,
1229 op_span.end().line,
1230 op_span.end().column
1231 )
1232 };
1233
1234 let work_fn = format_ident!(
1235 "{}__{}__{}",
1236 ident,
1237 op_name,
1238 source_tag,
1239 span = op_span
1240 );
1241 let work_fn_async = format_ident!("{}__async", work_fn, span = op_span);
1242
1243 let context_args = WriteContextArgs {
1244 root: &root,
1245 df_ident: df_local,
1246 context,
1247 subgraph_id,
1248 node_id,
1249 loop_id,
1250 op_span,
1251 op_tag: self.operator_tag.get(node_id).cloned(),
1252 work_fn: &work_fn,
1253 work_fn_async: &work_fn_async,
1254 ident: &ident,
1255 is_pull,
1256 inputs: &inputs,
1257 outputs: &outputs,
1258 singleton_output_ident,
1259 op_name,
1260 op_inst,
1261 arguments,
1262 arguments_handles,
1263 };
1264
1265 let write_result =
1266 (op_constraints.write_fn)(&context_args, diagnostics);
1267 let OperatorWriteOutput {
1268 write_prologue,
1269 write_iterator,
1270 write_iterator_after,
1271 write_tick_end,
1272 } = write_result.unwrap_or_else(|()| {
1273 assert!(
1274 diagnostics.has_error(),
1275 "Operator `{}` returned `Err` but emitted no diagnostics, this is a bug.",
1276 op_name,
1277 );
1278 OperatorWriteOutput {
1279 write_iterator: null_write_iterator_fn(&context_args),
1280 ..Default::default()
1281 }
1282 });
1283
1284 op_prologue_code.push(syn::parse_quote! {
1285 #[allow(non_snake_case)]
1286 #[inline(always)]
1287 fn #work_fn<T>(thunk: impl ::std::ops::FnOnce() -> T) -> T {
1288 thunk()
1289 }
1290
1291 #[allow(non_snake_case)]
1292 #[inline(always)]
1293 async fn #work_fn_async<T>(
1294 thunk: impl ::std::future::Future<Output = T>,
1295 ) -> T {
1296 thunk.await
1297 }
1298 });
1299 op_prologue_code.push(write_prologue);
1300 op_tick_end_code.push(write_tick_end);
1301 subgraph_op_iter_code.push(write_iterator);
1302
1303 if include_type_guards {
1304 let type_guard = if is_pull {
1305 quote_spanned! {op_span=>
1306 let #ident = {
1307 #[allow(non_snake_case)]
1308 #[inline(always)]
1309 pub fn #work_fn<Item, Input>(input: Input)
1310 -> impl #root::dfir_pipes::pull::Pull<Item = Item, Meta = (), CanPend = Input::CanPend, CanEnd = Input::CanEnd>
1311 where
1312 Input: #root::dfir_pipes::pull::Pull<Item = Item, Meta = ()>,
1313 {
1314 #root::pin_project_lite::pin_project! {
1315 #[repr(transparent)]
1316 struct Pull<Item, Input: #root::dfir_pipes::pull::Pull<Item = Item>> {
1317 #[pin]
1318 inner: Input
1319 }
1320 }
1321
1322 impl<Item, Input> #root::dfir_pipes::pull::Pull for Pull<Item, Input>
1323 where
1324 Input: #root::dfir_pipes::pull::Pull<Item = Item>,
1325 {
1326 type Ctx<'ctx> = Input::Ctx<'ctx>;
1327
1328 type Item = Item;
1329 type Meta = Input::Meta;
1330 type CanPend = Input::CanPend;
1331 type CanEnd = Input::CanEnd;
1332
1333 #[inline(always)]
1334 fn pull(
1335 self: ::std::pin::Pin<&mut Self>,
1336 ctx: &mut Self::Ctx<'_>,
1337 ) -> #root::dfir_pipes::pull::PullStep<Self::Item, Self::Meta, Self::CanPend, Self::CanEnd> {
1338 #root::dfir_pipes::pull::Pull::pull(self.project().inner, ctx)
1339 }
1340
1341 #[inline(always)]
1342 fn size_hint(&self) -> (usize, Option<usize>) {
1343 #root::dfir_pipes::pull::Pull::size_hint(&self.inner)
1344 }
1345 }
1346
1347 Pull {
1348 inner: input
1349 }
1350 }
1351 #work_fn::<_, _>( #ident )
1352 };
1353 }
1354 } else {
1355 quote_spanned! {op_span=>
1356 let #ident = {
1357 #[allow(non_snake_case)]
1358 #[inline(always)]
1359 pub fn #work_fn<Item, Psh>(psh: Psh) -> impl #root::dfir_pipes::push::Push<Item, (), CanPend = Psh::CanPend>
1360 where
1361 Psh: #root::dfir_pipes::push::Push<Item, ()>
1362 {
1363 #root::pin_project_lite::pin_project! {
1364 #[repr(transparent)]
1365 struct PushGuard<Psh> {
1366 #[pin]
1367 inner: Psh,
1368 }
1369 }
1370
1371 impl<Item, Psh> #root::dfir_pipes::push::Push<Item, ()> for PushGuard<Psh>
1372 where
1373 Psh: #root::dfir_pipes::push::Push<Item, ()>,
1374 {
1375 type Ctx<'ctx> = Psh::Ctx<'ctx>;
1376
1377 type CanPend = Psh::CanPend;
1378
1379 #[inline(always)]
1380 fn poll_ready(
1381 self: ::std::pin::Pin<&mut Self>,
1382 ctx: &mut Self::Ctx<'_>,
1383 ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
1384 #root::dfir_pipes::push::Push::poll_ready(self.project().inner, ctx)
1385 }
1386
1387 #[inline(always)]
1388 fn start_send(
1389 self: ::std::pin::Pin<&mut Self>,
1390 item: Item,
1391 meta: (),
1392 ) {
1393 #root::dfir_pipes::push::Push::start_send(self.project().inner, item, meta)
1394 }
1395
1396 #[inline(always)]
1397 fn poll_flush(
1398 self: ::std::pin::Pin<&mut Self>,
1399 ctx: &mut Self::Ctx<'_>,
1400 ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
1401 #root::dfir_pipes::push::Push::poll_flush(self.project().inner, ctx)
1402 }
1403
1404 #[inline(always)]
1405 fn size_hint(
1406 self: ::std::pin::Pin<&mut Self>,
1407 hint: (usize, Option<usize>),
1408 ) {
1409 #root::dfir_pipes::push::Push::size_hint(self.project().inner, hint)
1410 }
1411 }
1412
1413 PushGuard {
1414 inner: psh
1415 }
1416 }
1417 #work_fn( #ident )
1418 };
1419 }
1420 };
1421 subgraph_op_iter_code.push(type_guard);
1422 }
1423 subgraph_op_iter_after_code.push(write_iterator_after);
1424 }
1425 }
1426
1427 {
1428 let pull_ident = if 0 < pull_to_push_idx {
1430 self.node_as_ident(subgraph_nodes[pull_to_push_idx - 1], false)
1431 } else {
1432 recv_port_idents[0].clone()
1434 };
1435
1436 #[rustfmt::skip]
1437 let push_ident = if let Some(&node_id) =
1438 subgraph_nodes.get(pull_to_push_idx)
1439 {
1440 self.node_as_ident(node_id, false)
1441 } else if 1 == send_port_idents.len() {
1442 send_port_idents[0].clone()
1444 } else {
1445 diagnostics.push(Diagnostic::spanned(
1446 pull_ident.span(),
1447 Level::Error,
1448 "Degenerate subgraph detected, is there a disconnected `null()` or other degenerate pipeline somewhere?",
1449 ));
1450 continue;
1451 };
1452
1453 let pivot_span = pull_ident
1455 .span()
1456 .join(push_ident.span())
1457 .unwrap_or_else(|| push_ident.span());
1458 let pivot_fn_ident = Ident::new(
1459 &format!("pivot_run_sg_{:?}", subgraph_id.data()),
1460 pivot_span,
1461 );
1462 let root = change_spans(root.clone(), pivot_span);
1463 subgraph_op_iter_code.push(quote_spanned! {pivot_span=>
1464 #[inline(always)]
1465 fn #pivot_fn_ident<Pul, Psh, Item>(pull: Pul, push: Psh)
1466 -> impl ::std::future::Future<Output = ()>
1467 where
1468 Pul: #root::dfir_pipes::pull::Pull<Item = Item>,
1469 Psh: #root::dfir_pipes::push::Push<Item, Pul::Meta>,
1470 {
1471 #root::dfir_pipes::pull::Pull::send_push(pull, push)
1472 }
1473 (#pivot_fn_ident)(#pull_ident, #push_ident).await;
1474 });
1475 }
1476 };
1477
1478 let sg_fut_ident = subgraph_id.as_ident(Span::call_site());
1482
1483 let send_metrics_code: Vec<TokenStream> = send_hoffs
1485 .iter()
1486 .zip(send_buf_idents.iter())
1487 .map(|(&hoff_id, buf_ident)| {
1488 let hoff_ffi = hoff_id.data().as_ffi();
1489 quote! {
1490 __dfir_metrics.handoffs[
1491 #root::slotmap::KeyData::from_ffi(#hoff_ffi).into()
1492 ].curr_items_count.set(#buf_ident.len());
1493 }
1494 })
1495 .collect();
1496
1497 subgraph_blocks.push(quote! {
1498 let #sg_fut_ident = async {
1499 let #context = &#df;
1500 #( #recv_port_code )*
1501 #( #send_port_code )*
1502 #( #subgraph_op_iter_code )*
1503 #( #subgraph_op_iter_after_code )*
1504 };
1505 {
1506 let sg_metrics = &__dfir_metrics.subgraphs[
1507 #root::slotmap::KeyData::from_ffi(#sg_metrics_ffi).into()
1508 ];
1509 #root::scheduled::metrics::InstrumentSubgraph::new(
1510 #sg_fut_ident, sg_metrics
1511 ).await;
1512 sg_metrics.total_run_count.update(|x| x + 1);
1513 }
1514 #( #send_metrics_code )*
1515 });
1516
1517 }
1520 }
1521
1522 if diagnostics.has_error() {
1523 return Err(std::mem::take(diagnostics));
1524 }
1525 let _ = diagnostics; let (meta_graph_arg, diagnostics_arg) = if include_meta {
1528 let meta_graph_json = serde_json::to_string(&self).unwrap();
1529 let meta_graph_json = Literal::string(&meta_graph_json);
1530
1531 let serde_diagnostics: Vec<_> = diagnostics.iter().map(Diagnostic::to_serde).collect();
1532 let diagnostics_json = serde_json::to_string(&*serde_diagnostics).unwrap();
1533 let diagnostics_json = Literal::string(&diagnostics_json);
1534
1535 (
1536 quote! { Some(#meta_graph_json) },
1537 quote! { Some(#diagnostics_json) },
1538 )
1539 } else {
1540 (quote! { None }, quote! { None })
1541 };
1542
1543 let metrics_init_code = {
1545 let handoff_inits = handoff_nodes.iter().map(|&(node_id, _)| {
1546 let ffi = node_id.data().as_ffi();
1547 quote! {
1548 dfir_metrics.handoffs.insert(
1549 #root::slotmap::KeyData::from_ffi(#ffi).into(),
1550 ::std::default::Default::default(),
1551 );
1552 }
1553 });
1554 let subgraph_inits = all_subgraphs.iter().map(|&(sg_id, _)| {
1555 let ffi = sg_id.data().as_ffi();
1556 quote! {
1557 dfir_metrics.subgraphs.insert(
1558 #root::slotmap::KeyData::from_ffi(#ffi).into(),
1559 ::std::default::Default::default(),
1560 );
1561 }
1562 });
1563 handoff_inits.chain(subgraph_inits).collect::<Vec<_>>()
1564 };
1565
1566 Ok(quote! {
1569 {
1570 #prefix
1571
1572 use #root::{var_expr, var_args};
1573
1574 let __dfir_wake_state = ::std::sync::Arc::new(
1575 #root::scheduled::context::WakeState::default()
1576 );
1577
1578 let __dfir_metrics = {
1579 let mut dfir_metrics = #root::scheduled::metrics::DfirMetrics::default();
1580 #( #metrics_init_code )*
1581 ::std::rc::Rc::new(dfir_metrics)
1582 };
1583
1584 #[allow(unused_mut)]
1585 let mut #df = #root::scheduled::context::Context::new(
1586 ::std::clone::Clone::clone(&__dfir_wake_state),
1587 __dfir_metrics,
1588 );
1589
1590 #( #buffer_code )*
1591 #( #back_buffer_code )*
1592 #( #op_prologue_code )*
1593
1594 let mut __dfir_work_done = true;
1599 #[allow(unused_qualifications, unused_mut, unused_variables, clippy::await_holding_refcell_ref, clippy::deref_addrof)]
1600 let __dfir_inline_tick = async move |#df: &mut #root::scheduled::context::Context| {
1601 let __dfir_metrics = #df.metrics();
1602 #( #back_edge_swap_code )*
1605 #( #subgraph_blocks )*
1606
1607 if false #( || !#defer_tick_buf_idents.is_empty() )* {
1610 #df.schedule_subgraph(true);
1611 }
1612
1613 #( #op_tick_end_code )*
1615
1616 #df.__end_tick();
1617 ::std::mem::take(&mut __dfir_work_done)
1618 };
1619 #root::scheduled::context::Dfir::new(
1620 __dfir_inline_tick,
1621 #df,
1622 #meta_graph_arg,
1623 #diagnostics_arg,
1624 )
1625 }
1626 })
1627 }
1628
1629 pub fn node_color_map(&self) -> SparseSecondaryMap<GraphNodeId, Color> {
1632 let mut node_color_map: SparseSecondaryMap<GraphNodeId, Color> = self
1633 .node_ids()
1634 .filter_map(|node_id| {
1635 let op_color = self.node_color(node_id)?;
1636 Some((node_id, op_color))
1637 })
1638 .collect();
1639
1640 for sg_nodes in self.subgraph_nodes.values() {
1642 let pull_to_push_idx = self.find_pull_to_push_idx(sg_nodes);
1643
1644 for (idx, node_id) in sg_nodes.iter().copied().enumerate() {
1645 let is_pull = idx < pull_to_push_idx;
1646 node_color_map.insert(node_id, if is_pull { Color::Pull } else { Color::Push });
1647 }
1648 }
1649
1650 node_color_map
1651 }
1652
1653 pub fn to_mermaid(&self, write_config: &WriteConfig) -> String {
1655 let mut output = String::new();
1656 self.write_mermaid(&mut output, write_config).unwrap();
1657 output
1658 }
1659
1660 pub fn write_mermaid(
1662 &self,
1663 output: impl std::fmt::Write,
1664 write_config: &WriteConfig,
1665 ) -> std::fmt::Result {
1666 let mut graph_write = Mermaid::new(output);
1667 self.write_graph(&mut graph_write, write_config)
1668 }
1669
1670 pub fn to_dot(&self, write_config: &WriteConfig) -> String {
1672 let mut output = String::new();
1673 let mut graph_write = Dot::new(&mut output);
1674 self.write_graph(&mut graph_write, write_config).unwrap();
1675 output
1676 }
1677
1678 pub fn write_dot(
1680 &self,
1681 output: impl std::fmt::Write,
1682 write_config: &WriteConfig,
1683 ) -> std::fmt::Result {
1684 let mut graph_write = Dot::new(output);
1685 self.write_graph(&mut graph_write, write_config)
1686 }
1687
1688 pub(crate) fn write_graph<W>(
1690 &self,
1691 mut graph_write: W,
1692 write_config: &WriteConfig,
1693 ) -> Result<(), W::Err>
1694 where
1695 W: GraphWrite,
1696 {
1697 fn helper_edge_label(
1698 src_port: &PortIndexValue,
1699 dst_port: &PortIndexValue,
1700 ) -> Option<String> {
1701 let src_label = match src_port {
1702 PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1703 PortIndexValue::Int(index) => Some(index.value.to_string()),
1704 _ => None,
1705 };
1706 let dst_label = match dst_port {
1707 PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1708 PortIndexValue::Int(index) => Some(index.value.to_string()),
1709 _ => None,
1710 };
1711 let label = match (src_label, dst_label) {
1712 (Some(l1), Some(l2)) => Some(format!("{}\n{}", l1, l2)),
1713 (Some(l1), None) => Some(l1),
1714 (None, Some(l2)) => Some(l2),
1715 (None, None) => None,
1716 };
1717 label
1718 }
1719
1720 let node_color_map = self.node_color_map();
1722
1723 graph_write.write_prologue()?;
1725
1726 let mut skipped_handoffs = BTreeSet::new();
1728 let mut subgraph_handoffs = <BTreeMap<GraphSubgraphId, Vec<GraphNodeId>>>::new();
1729 for (node_id, node) in self.nodes() {
1730 if matches!(node, GraphNode::Handoff { .. }) {
1731 if write_config.no_handoffs {
1732 skipped_handoffs.insert(node_id);
1733 continue;
1734 } else {
1735 let pred_node = self.node_predecessor_nodes(node_id).next().unwrap();
1736 let pred_sg = self.node_subgraph(pred_node);
1737 let succ_node = self.node_successor_nodes(node_id).next().unwrap();
1738 let succ_sg = self.node_subgraph(succ_node);
1739 if let Some((pred_sg, succ_sg)) = pred_sg.zip(succ_sg)
1740 && pred_sg == succ_sg
1741 {
1742 subgraph_handoffs.entry(pred_sg).or_default().push(node_id);
1743 }
1744 }
1745 }
1746 graph_write.write_node_definition(
1747 node_id,
1748 &if write_config.op_short_text {
1749 node.to_name_string()
1750 } else if write_config.op_text_no_imports {
1751 let full_text = node.to_pretty_string();
1753 let mut output = String::new();
1754 for sentence in full_text.split('\n') {
1755 if sentence.trim().starts_with("use") {
1756 continue;
1757 }
1758 output.push('\n');
1759 output.push_str(sentence);
1760 }
1761 output.into()
1762 } else {
1763 node.to_pretty_string()
1764 },
1765 if write_config.no_pull_push {
1766 None
1767 } else {
1768 node_color_map.get(node_id).copied()
1769 },
1770 )?;
1771 }
1772
1773 for (edge_id, (src_id, mut dst_id)) in self.edges() {
1775 if skipped_handoffs.contains(&src_id) {
1777 continue;
1778 }
1779
1780 let (src_port, mut dst_port) = self.edge_ports(edge_id);
1781 if skipped_handoffs.contains(&dst_id) {
1782 let mut handoff_succs = self.node_successors(dst_id);
1783 assert_eq!(1, handoff_succs.len());
1784 let (succ_edge, succ_node) = handoff_succs.next().unwrap();
1785 dst_id = succ_node;
1786 dst_port = self.edge_ports(succ_edge).1;
1787 }
1788
1789 let label = helper_edge_label(src_port, dst_port);
1790 let delay_type = self
1791 .node_op_inst(dst_id)
1792 .and_then(|op_inst| (op_inst.op_constraints.input_delaytype_fn)(dst_port));
1793 graph_write.write_edge(src_id, dst_id, delay_type, label.as_deref(), false)?;
1794 }
1795
1796 if !write_config.no_references {
1798 for dst_id in self.node_ids() {
1799 for src_ref_id in self
1800 .node_singleton_references(dst_id)
1801 .iter()
1802 .copied()
1803 .flatten()
1804 {
1805 let delay_type = None;
1806 let label = None;
1807 graph_write.write_edge(src_ref_id, dst_id, delay_type, label, true)?;
1808 }
1809 }
1810 }
1811
1812 let loop_subgraphs = self.subgraph_ids().map(|sg_id| {
1820 let loop_id = if write_config.no_loops {
1821 None
1822 } else {
1823 self.subgraph_loop(sg_id)
1824 };
1825 (loop_id, sg_id)
1826 });
1827 let loop_subgraphs = into_group_map(loop_subgraphs);
1828 for (loop_id, subgraph_ids) in loop_subgraphs {
1829 if let Some(loop_id) = loop_id {
1830 graph_write.write_loop_start(loop_id)?;
1831 }
1832
1833 let subgraph_varnames_nodes = subgraph_ids.into_iter().flat_map(|sg_id| {
1835 self.subgraph(sg_id).iter().copied().map(move |node_id| {
1836 let opt_sg_id = if write_config.no_subgraphs {
1837 None
1838 } else {
1839 Some(sg_id)
1840 };
1841 (opt_sg_id, (self.node_varname(node_id), node_id))
1842 })
1843 });
1844 let subgraph_varnames_nodes = into_group_map(subgraph_varnames_nodes);
1845 for (sg_id, varnames) in subgraph_varnames_nodes {
1846 if let Some(sg_id) = sg_id {
1847 graph_write.write_subgraph_start(sg_id)?;
1848 }
1849
1850 let varname_nodes = varnames.into_iter().map(|(varname, node)| {
1852 let varname = if write_config.no_varnames {
1853 None
1854 } else {
1855 varname
1856 };
1857 (varname, node)
1858 });
1859 let varname_nodes = into_group_map(varname_nodes);
1860 for (varname, node_ids) in varname_nodes {
1861 if let Some(varname) = varname {
1862 graph_write.write_varname_start(&varname.0.to_string(), sg_id)?;
1863 }
1864
1865 for node_id in node_ids {
1867 graph_write.write_node(node_id)?;
1868 }
1869
1870 if varname.is_some() {
1871 graph_write.write_varname_end()?;
1872 }
1873 }
1874
1875 if sg_id.is_some() {
1876 graph_write.write_subgraph_end()?;
1877 }
1878 }
1879
1880 if loop_id.is_some() {
1881 graph_write.write_loop_end()?;
1882 }
1883 }
1884
1885 graph_write.write_epilogue()?;
1887
1888 Ok(())
1889 }
1890
1891 pub fn surface_syntax_string(&self) -> String {
1893 let mut string = String::new();
1894 self.write_surface_syntax(&mut string).unwrap();
1895 string
1896 }
1897
1898 pub fn write_surface_syntax(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1900 for (key, node) in self.nodes.iter() {
1901 match node {
1902 GraphNode::Operator(op) => {
1903 writeln!(write, "{:?} = {};", key.data(), op.to_token_stream())?;
1904 }
1905 GraphNode::Handoff { .. } => {
1906 writeln!(write, "// {:?} = <handoff>;", key.data())?;
1907 }
1908 GraphNode::ModuleBoundary { .. } => panic!(),
1909 }
1910 }
1911 writeln!(write)?;
1912 for (_e, (src_key, dst_key)) in self.graph.edges() {
1913 writeln!(write, "{:?} -> {:?};", src_key.data(), dst_key.data())?;
1914 }
1915 Ok(())
1916 }
1917
1918 pub fn mermaid_string_flat(&self) -> String {
1920 let mut string = String::new();
1921 self.write_mermaid_flat(&mut string).unwrap();
1922 string
1923 }
1924
1925 pub fn write_mermaid_flat(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1927 writeln!(write, "flowchart TB")?;
1928 for (key, node) in self.nodes.iter() {
1929 match node {
1930 GraphNode::Operator(operator) => writeln!(
1931 write,
1932 " %% {span}\n {id:?}[\"{row_col} <tt>{code}</tt>\"]",
1933 span = PrettySpan(node.span()),
1934 id = key.data(),
1935 row_col = PrettyRowCol(node.span()),
1936 code = operator
1937 .to_token_stream()
1938 .to_string()
1939 .replace('&', "&")
1940 .replace('<', "<")
1941 .replace('>', ">")
1942 .replace('"', """)
1943 .replace('\n', "<br>"),
1944 ),
1945 GraphNode::Handoff { .. } => {
1946 writeln!(write, r#" {:?}{{"{}"}}"#, key.data(), HANDOFF_NODE_STR)
1947 }
1948 GraphNode::ModuleBoundary { .. } => {
1949 writeln!(
1950 write,
1951 r#" {:?}{{"{}"}}"#,
1952 key.data(),
1953 MODULE_BOUNDARY_NODE_STR
1954 )
1955 }
1956 }?;
1957 }
1958 writeln!(write)?;
1959 for (_e, (src_key, dst_key)) in self.graph.edges() {
1960 writeln!(write, " {:?}-->{:?}", src_key.data(), dst_key.data())?;
1961 }
1962 Ok(())
1963 }
1964}
1965
1966impl DfirGraph {
1968 pub fn loop_ids(&self) -> slotmap::basic::Keys<'_, GraphLoopId, Vec<GraphNodeId>> {
1970 self.loop_nodes.keys()
1971 }
1972
1973 pub fn loops(&self) -> slotmap::basic::Iter<'_, GraphLoopId, Vec<GraphNodeId>> {
1975 self.loop_nodes.iter()
1976 }
1977
1978 pub fn insert_loop(&mut self, parent_loop: Option<GraphLoopId>) -> GraphLoopId {
1980 let loop_id = self.loop_nodes.insert(Vec::new());
1981 self.loop_children.insert(loop_id, Vec::new());
1982 if let Some(parent_loop) = parent_loop {
1983 self.loop_parent.insert(loop_id, parent_loop);
1984 self.loop_children
1985 .get_mut(parent_loop)
1986 .unwrap()
1987 .push(loop_id);
1988 } else {
1989 self.root_loops.push(loop_id);
1990 }
1991 loop_id
1992 }
1993
1994 pub fn node_loop(&self, node_id: GraphNodeId) -> Option<GraphLoopId> {
1996 self.node_loops.get(node_id).copied()
1997 }
1998
1999 pub fn subgraph_loop(&self, subgraph_id: GraphSubgraphId) -> Option<GraphLoopId> {
2001 let &node_id = self.subgraph(subgraph_id).first().unwrap();
2002 let out = self.node_loop(node_id);
2003 debug_assert!(
2004 self.subgraph(subgraph_id)
2005 .iter()
2006 .all(|&node_id| self.node_loop(node_id) == out),
2007 "Subgraph nodes should all have the same loop context."
2008 );
2009 out
2010 }
2011
2012 pub fn loop_parent(&self, loop_id: GraphLoopId) -> Option<GraphLoopId> {
2014 self.loop_parent.get(loop_id).copied()
2015 }
2016
2017 pub fn loop_children(&self, loop_id: GraphLoopId) -> &Vec<GraphLoopId> {
2019 self.loop_children.get(loop_id).unwrap()
2020 }
2021}
2022
/// Options controlling how a `DfirGraph` is rendered by its Mermaid/DOT writers.
///
/// Every flag defaults to `false`, which gives the most detailed output.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "clap-derive", derive(clap::Args))]
pub struct WriteConfig {
    /// Do not group nodes into subgraph blocks.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_subgraphs: bool,
    /// Do not group nodes by their variable name.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_varnames: bool,
    /// Do not color nodes (e.g. as pull or push).
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_pull_push: bool,
    /// Omit handoff nodes, connecting their inputs directly to their outputs.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_handoffs: bool,
    /// Omit singleton-reference edges.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_references: bool,
    /// Do not group subgraphs into loop blocks.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_loops: bool,

    /// Label nodes with their short name instead of their full text.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_short_text: bool,
    /// When labels show full node text, drop `use` import lines from it.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_text_no_imports: bool,
}
2053
/// Output format used when rendering a `DfirGraph`.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "clap-derive", derive(clap::Parser, clap::ValueEnum))]
pub enum WriteGraphType {
    /// Mermaid flowchart syntax.
    Mermaid,
    /// Graphviz DOT syntax.
    Dot,
}
2063
/// Groups `(key, value)` pairs by key into a [`BTreeMap`].
///
/// Groups iterate in ascending key order (deterministic, unlike a hash map). Within each
/// group, values keep the relative order in which they appeared in `iter`.
fn into_group_map<K, V>(iter: impl IntoIterator<Item = (K, V)>) -> BTreeMap<K, Vec<V>>
where
    K: Ord,
{
    iter.into_iter()
        .fold(BTreeMap::new(), |mut groups: BTreeMap<K, Vec<V>>, (key, value)| {
            groups.entry(key).or_default().push(value);
            groups
        })
}