from typing import Any

from .CFG_Node import *


class CFG:
    """Graph wrapper that holds the entry and exit nodes of a CFG.

    When an AST is supplied, empty nodes are pruned once at construction
    time so that ``to_dot()`` only has to format the graph.
    """

    def __init__(self, in_node: CFG_Node, out_node: CFG_Node, ast=None):
        """Store the entry/exit nodes and optionally filter the graph.

        Args:
            in_node: Entry node; every traversal in this class starts here.
            out_node: Exit node; stored but not otherwise used by this class.
            ast: Optional AST. When it is not ``None`` the graph is filtered
                immediately via ``_filter_graph()``.
        """
        self.in_node = in_node
        self.out_node = out_node
        self.ast = ast

        # If an AST is provided, filter the graph by removing empty nodes.
        if ast is not None:
            self._filter_graph()

    def _filter_graph(self):
        """Remove empty nodes from the graph and rewire edges around them.

        This is done once during construction, not during ``to_dot()``.
        """
        # Collect every node reachable from the entry node.
        all_nodes = set()
        self._collect_nodes(self.in_node, all_nodes)

        # Decide what to remove before mutating anything. Sorting by id makes
        # the removal order deterministic instead of depending on set order.
        nodes_to_remove = sorted(
            (node for node in all_nodes if self._should_remove_node(node)),
            key=lambda n: n.id,
        )

        for node in nodes_to_remove:
            self._remove_node_and_rewire(node)

    def _collect_nodes(self, node, node_set):
        """Add ``node`` and everything reachable through ``children`` to ``node_set``.

        Uses an explicit stack rather than recursion so that long chains of
        nodes cannot exhaust the interpreter's recursion limit.
        """
        pending = [node]
        while pending:
            current = pending.pop()
            if current in node_set:
                continue
            node_set.add(current)
            pending.extend(current.children)

    def _should_remove_node(self, node):
        """Return True if ``node`` should be dropped from the graph.

        A node is removed when:
          * its ``label`` is ``None`` or the string ``"None"``, it has no
            ``ast_node``, and its ``dot_label()`` is not ``"START"``; or
          * its ``dot_label()`` is ``"END"`` and its label does not contain
            both ``(`` and ``)`` (labels with parentheses are treated as
            function-specific END nodes and kept).
        """
        # Empty label: check for both None and the literal "None" string.
        if hasattr(node, 'label') and (node.label is None or node.label == "None"):
            # Nodes backed by an AST node get their label from the AST later.
            if node.ast_node is not None:
                return False
            # Global START nodes have no label but must still be shown.
            if hasattr(node, 'dot_label') and node.dot_label() == "START":
                return False
            return True

        # END nodes: keep function-specific ones, drop the global ones.
        if hasattr(node, 'dot_label') and node.dot_label() == "END":
            label = getattr(node, 'label', None)
            if label and '(' in label and ')' in label:
                return False
            return True

        return False

    def _remove_node_and_rewire(self, node):
        """Detach ``node`` and connect each of its parents to its successors.

        For a diamond-shaped node only the first two children are followed
        (in stored order) so the T/F branch order is preserved; duplicates
        are kept. For any other node every child is followed and duplicate
        targets are dropped. Each followed child is replaced by its first
        non-removable descendant.
        """
        # Snapshot the children; the node's own lists are cleared at the end.
        original_children = list(node.children)
        is_branch = hasattr(node, 'dot_shape') and node.dot_shape() == "diamond"

        for parent in list(node.parents):
            if node not in parent.children:
                continue

            # Targets are recomputed per parent because rewiring an earlier
            # parent may change what is reachable through removable nodes.
            targets = []
            if is_branch:
                for branch in original_children[:2]:
                    target = self._find_first_non_empty_child(branch)
                    if target:
                        targets.append(target)
            else:
                for child in original_children:
                    target = self._find_first_non_empty_child(child)
                    if target and target not in targets:
                        targets.append(target)

            # Replace the edge parent -> node with edges parent -> targets.
            parent.remove_child(node, propagate=False)
            for target in targets:
                parent.add_child(target, propagate=False)

        # Clear the removed node's own connections.
        node.parents.clear()
        node.children.clear()

    def _find_first_non_empty_child(self, node):
        """Return the first non-removable node at or below ``node``.

        Performs a depth-first, pre-order search that visits children in
        ascending ``id`` order. Returns ``None`` when every reachable node is
        removable. A seen-set guards against cycles of removable nodes,
        which would otherwise never terminate.
        """
        seen = set()
        stack = [node]
        while stack:
            current = stack.pop()
            if current in seen:
                continue
            seen.add(current)
            if not self._should_remove_node(current):
                return current
            # Push in reverse so the smallest id is popped (visited) first.
            stack.extend(reversed(sorted(current.children, key=lambda n: n.id)))
        return None

    @staticmethod
    def _escape_dot_label(label) -> str:
        """Return ``label`` as text safe to embed in a DOT double-quoted string.

        Already-escaped quotes are normalised first so the result is the same
        whether or not the caller escaped them.
        """
        return str(label).replace('\\"', '"').replace('"', '\\"')

    def to_dot(self) -> str:
        """Convert the CFG to DOT format.

        This method only formats; all graph filtering and modification is
        done in the constructor. Nodes are emitted depth-first from
        ``in_node`` with children visited in ascending ``id`` order. Each
        outgoing edge is written immediately before its target's subtree.
        The first two edges of a diamond-shaped node are labelled T and F.

        Returns:
            The DOT source as a single newline-joined string.
        """
        visited = set()
        lines = ["digraph CFG {", ' node [fontname="Helvetica"];']

        def node_label(node: CFG_Node) -> str | None:
            # A custom label, when present, always wins.
            if hasattr(node, 'label') and node.label:
                return node.label

            # Otherwise combine the node's base label with its AST text.
            base = node.dot_label() if hasattr(node, "dot_label") else ""
            if node.ast_node is not None:
                semantic = str(node.ast_node)
                return f"{base}\n{semantic}" if base else semantic

            return base if base else None

        def node_style(node: CFG_Node) -> str:
            # Colour call/return and start/end nodes based on label prefix.
            label = getattr(node, 'label', None)
            if not label:
                return ''
            if label.startswith(('CALL', 'RET')):
                return 'style=filled, color=orange'
            if label.startswith(('START', 'END')):
                return 'style=filled, color=green'
            return ''

        # Explicit DFS stack of (node, is_diamond, remaining (index, child)
        # pairs); avoids recursion so deep graphs cannot hit the frame limit.
        stack = []

        def open_node(node: CFG_Node) -> None:
            """Emit the declaration for ``node`` and queue its children."""
            if node.id in visited:
                return
            visited.add(node.id)

            label = node_label(node)
            if label is None:
                # This shouldn't happen if the constructor did its job.
                return

            shape = node.dot_shape() if hasattr(node, "dot_shape") else "box"
            style = node_style(node)
            style_str = f", {style}" if style else ""
            lines.append(
                f' n{node.id} [label="{self._escape_dot_label(label)}", '
                f'shape={shape}{style_str}];'
            )
            children = sorted(node.children, key=lambda n: n.id)
            stack.append((node, shape == "diamond", enumerate(children)))

        # Start the traversal from the entry node.
        open_node(self.in_node)
        while stack:
            parent, is_diamond, remaining = stack[-1]
            step = next(remaining, None)
            if step is None:
                stack.pop()
                continue

            index, child = step
            # Label the first two edges of a conditional (diamond) node.
            edge_label = ""
            if is_diamond:
                if index == 0:
                    edge_label = ' [label="T"]'
                elif index == 1:
                    edge_label = ' [label="F"]'
            lines.append(f" n{parent.id} -> n{child.id}{edge_label};")
            open_node(child)

        lines.append("}")
        return "\n".join(lines)