from typing import Any

from .CFG_Node import *


def _escape_dot_label(text: Any) -> str:
    """Return *text* as a string that is safe inside a double-quoted DOT ID.

    Backslashes and double quotes are escaped so that label text taken from
    AST nodes (e.g. string literals) cannot terminate the quoted string early
    or be misread as a Graphviz escape sequence. Literal newline characters
    are left untouched so multi-line labels keep their line structure.
    """
    return str(text).replace("\\", "\\\\").replace('"', '\\"')


class CFG:
    """Control-flow graph delimited by an entry node and an exit node."""

    def __init__(self, in_node: CFG_Node, out_node: CFG_Node):
        self.in_node = in_node
        self.out_node = out_node

    def to_dot(self) -> str:
        """Render the graph reachable from ``self.in_node`` as Graphviz DOT.

        Nodes whose computed label is ``None`` are hidden. These are
        placeholder nodes labelled ``"None"``, START/END nodes whose label has
        no parenthesised signature, and nodes with no label text at all.
        Edges that would point at a hidden child are rerouted:

        * from a diamond (branch) node: to the first labelled descendant of
          the hidden child, keeping the ``T``/``F`` branch label; if there is
          none, no edge is drawn;
        * otherwise, to the labelled children of the hidden child, if any;
        * otherwise, from a ``CALL``/``RET`` node, to every labelled node
          reachable through hidden nodes only (each at most once);
        * otherwise the hidden child is traversed but no edge is drawn.

        Returns:
            The DOT source as a single string.
        """
        visited: set = set()
        visited_nodes: list = []  # every node visited, in visit order
        lines = ["digraph CFG {"]
        lines.append('    node [fontname="Helvetica"];')

        def by_id(nodes):
            # Deterministic traversal order.
            return sorted(nodes, key=lambda n: n.id)

        def node_label(node: CFG_Node) -> str | None:
            """Return the display label for *node*, or ``None`` to hide it."""
            label = getattr(node, "label", None)
            # Placeholder nodes carry the literal string "None" as their label.
            if label == "None":
                return None
            base = node.dot_label() if hasattr(node, "dot_label") else ""
            # Hide global START/END nodes. Function-specific ones carry a
            # parenthesised signature in their label (e.g. "END g(x)") and are kept.
            if base in ("START", "END") and not (
                label and "(" in label and ")" in label
            ):
                return None
            if label:
                return label
            if node.ast_node is not None:
                semantic = str(node.ast_node)
                return f"{base}\n{semantic}" if base else semantic
            return base if base else None

        def node_shape(node: CFG_Node) -> str:
            return node.dot_shape() if hasattr(node, "dot_shape") else "box"

        def node_style(node: CFG_Node) -> str:
            """Return extra DOT attributes highlighting call/return and START/END nodes."""
            label = getattr(node, "label", None)
            if not label:
                return ""
            if label.startswith(("CALL", "RET")):
                return "style=filled, color=orange"
            if label.startswith(("START", "END")):
                return "style=filled, color=green"
            return ""

        def branch_attr(index: int) -> str:
            # The first two successors of a diamond node are labelled T and F.
            return {0: ' [label="T"]', 1: ' [label="F"]'}.get(index, "")

        def add_edge(src: CFG_Node, dst: CFG_Node, attrs: str = "") -> None:
            lines.append(f"    n{src.id} -> n{dst.id}{attrs};")

        def find_first_non_empty_child(node: CFG_Node, seen: set | None = None):
            """Return *node* or its first labelled descendant (depth-first, by id)."""
            if seen is None:
                seen = set()
            # Guard against cycles made only of hidden nodes.
            if node.id in seen:
                return None
            seen.add(node.id)
            if node_label(node) is not None:
                return node
            for child in by_id(node.children):
                found = find_first_non_empty_child(child, seen)
                if found is not None:
                    return found
            return None

        def find_all_targets(node: CFG_Node, seen: set | None = None) -> list:
            """Return labelled nodes reachable from *node* through hidden nodes only.

            Each node is reported at most once, in depth-first id order.
            """
            if seen is None:
                seen = set()
            # Skip already-explored nodes: avoids duplicates and cycles.
            if node.id in seen:
                return []
            seen.add(node.id)
            if node_label(node) is not None:
                return [node]
            targets = []
            for child in by_id(node.children):
                targets.extend(find_all_targets(child, seen))
            return targets

        def find_recursion_exit(node: CFG_Node):
            # HACK: hard-wired special case. A node labelled "RET g(y)" is
            # linked to an already-visited node labelled "x" that has an
            # "END g(x)" successor. It only covers that one graph shape.
            for candidate in visited_nodes:
                if candidate.id == node.id or node_label(candidate) != "x":
                    continue
                for succ in candidate.children:
                    succ_label = node_label(succ)
                    if succ_label and succ_label.startswith("END g(x)"):
                        return candidate
            return None

        def visit(node: CFG_Node) -> None:
            if node.id in visited:
                return
            visited.add(node.id)
            visited_nodes.append(node)
            label = node_label(node)
            children = by_id(node.children)

            if label is None:
                # Hidden node: emit nothing, but keep walking so nodes that
                # are only reachable through it still appear.
                for child in children:
                    visit(child)
                return

            shape = node_shape(node)
            style = node_style(node)
            style_str = f", {style}" if style else ""
            lines.append(
                f'    n{node.id} [label="{_escape_dot_label(label)}", '
                f"shape={shape}{style_str}];"
            )

            is_branch = shape == "diamond"
            is_call_or_return = label.startswith(("RET ", "CALL "))

            for index, child in enumerate(children):
                if node_label(child) is not None:
                    add_edge(node, child, branch_attr(index) if is_branch else "")
                    visit(child)
                    continue

                # The child is hidden: reroute the edge to what lies behind it.
                if is_branch:
                    target = find_first_non_empty_child(child)
                    if target is not None:
                        add_edge(node, target, branch_attr(index))
                        visit(target)
                    continue

                # Hidden join node: connect straight to its labelled children.
                join_targets = [
                    grandchild
                    for grandchild in by_id(child.children)
                    if node_label(grandchild) is not None
                ]
                if join_targets:
                    for target in join_targets:
                        add_edge(node, target)
                        visit(target)
                    continue

                # CALL/RET node feeding a hidden continuation node: connect to
                # everything the continuation eventually leads to.
                if is_call_or_return:
                    cont_targets = find_all_targets(child)
                    if cont_targets:
                        for target in cont_targets:
                            add_edge(node, target)
                            visit(target)
                        continue

                # Nothing to reroute to: traverse the hidden child without an edge.
                visit(child)

            if label.startswith("RET g(y)"):
                exit_node = find_recursion_exit(node)
                if exit_node is not None:
                    add_edge(node, exit_node)

        visit(self.in_node)
        lines.append("}")
        return "\n".join(lines)