from typing import Any, Callable, Set

from .CFG_Node import *


class CFG:
    """Graph of CFG nodes built from an AST, with empty nodes spliced out.

    Construction creates a ``CFG_START`` and a ``CFG_END`` node, asks the AST
    to build the graph between them via ``ast.cfa(start, end)``, and then
    removes every reachable node whose ``is_empty()`` is true, reconnecting
    its parents to the nearest non-empty successors.

    Attributes:
        START: The ``CFG_START`` node; every traversal begins here.
        END: The ``CFG_END`` node.
        ast: The AST object the graph was built from.
    """

    def __init__(self, ast):
        """Build the graph for *ast* and prune its empty nodes.

        Args:
            ast: Object exposing ``cfa(start, end)``. Its return value is
                treated as the last open node of the graph: when it is not
                ``None`` it is linked to the end node.
                NOTE(review): presumably ``None`` means the flow was already
                terminated/linked by ``cfa`` itself - confirm against the
                ``cfa`` implementations.
        """
        start = CFG_START()
        start.dot_style = 'style=filled, color=gray'
        end = CFG_END()
        end.dot_style = 'style=filled, color=gray'

        last = ast.cfa(start, end)
        if last is not None:
            last.add_child(end)

        self.START = start
        self.END = end
        self.ast = ast

        # Snapshot the empty nodes before mutating the graph, and process
        # them in id order: iterating the raw set would make the removal
        # order (and thus the order in which edges are re-added) arbitrary.
        empty_nodes = sorted(
            (node for node in self.nodes() if node.is_empty()),
            key=lambda n: n.id,
        )
        for node in empty_nodes:
            self.__remove_and_rewire(node)

    def nodes(self) -> Set[CFG_Node]:
        """Return the set of all nodes reachable from ``START`` via children."""
        reachable: Set[CFG_Node] = set()
        self.__collect_nodes(self.START, reachable)
        return reachable

    def __collect_nodes(self, node, node_set):
        """Add *node* and everything reachable from it to *node_set*.

        Nodes already present in *node_set* are not expanded again. An
        explicit stack is used instead of recursion so that long chains of
        nodes cannot exhaust the interpreter's recursion limit.
        """
        pending = [node]
        while pending:
            current = pending.pop()
            if current in node_set:
                continue
            node_set.add(current)
            pending.extend(current.children)

    def __remove_and_rewire(self, node):
        """Detach *node* and connect each of its parents to its successors.

        For every parent that still lists *node* as a child, the edge to
        *node* is removed and replaced by edges to the first non-empty node
        found below each of *node*'s original children. Finally *node*'s own
        parent and child collections are cleared.

        For a ``CFG_DIAMOND`` only the first two children (treated as the
        true and false bodies) are considered and duplicates are kept; for
        any other node all children are considered and duplicate targets are
        dropped.
        """
        original_children = list(node.children)
        is_diamond = isinstance(node, CFG_DIAMOND)

        for parent in list(node.parents):
            if node not in parent.children:
                continue

            # Targets are recomputed for every parent on purpose: the edge
            # changes made for an earlier parent alter the graph that the
            # search walks.
            if is_diamond:
                # Preserve the true and false bodies, in that order.
                candidates = (
                    self.__first_filled_child(child)
                    for child in original_children[:2]
                )
                targets = [t for t in candidates if t is not None]
            else:
                targets = []
                for child in original_children:
                    target = self.__first_filled_child(child)
                    if target is not None and target not in targets:
                        targets.append(target)

            # Replace the parent -> node edge by parent -> target edges.
            parent.remove_child(node, propagate=False)
            for target in targets:
                parent.add_child(target, propagate=False)

        # Clear the node's own connections.
        node.parents.clear()
        node.children.clear()

    def __first_filled_child(self, node):
        """Return the first non-empty node at or below *node*, else ``None``.

        The search is a depth-first pre-order walk that visits children in
        ascending ``id`` order, so *node* itself is returned when it is not
        empty. Each node is examined at most once, which guarantees
        termination even when empty nodes form a cycle; a path that only
        leads back into already-examined empty nodes yields no result.
        """
        seen = set()
        pending = [node]
        while pending:
            current = pending.pop()
            if current in seen:
                continue
            seen.add(current)
            if not current.is_empty():
                return current
            # Push in reverse so the lowest-id child is popped first.
            pending.extend(
                reversed(sorted(current.children, key=lambda n: n.id))
            )
        return None

    def to_dot(self) -> str:
        """Render the graph reachable from ``START`` as Graphviz DOT text.

        Every node becomes ``n<id>`` with its ``dot_label()``, ``dot_shape``
        and, when set, ``dot_style``. Outgoing edges are emitted in ascending
        child ``id`` order; for ``CFG_DIAMOND`` nodes the first such edge is
        labelled ``T`` and the second ``F``.

        Returns:
            The DOT source as a single newline-joined string.
        """
        lines = ["digraph CFG {", ' node [fontname="Helvetica"];']
        branch_labels = (' [label="T"]', ' [label="F"]')

        def emit(node: CFG_Node):
            label = node.dot_label()
            shape = node.dot_shape
            style = node.dot_style
            style_str = f", {style}" if style else ""
            # NOTE(review): the label is interpolated unescaped; this assumes
            # dot_label() already returns DOT-safe text - confirm.
            lines.append(f' n{node.id} [label="{label}", shape={shape}{style_str}];')

            # NOTE(review): T/F are assigned by position in id-sorted order,
            # whereas the rewiring step treats children[0]/children[1] (list
            # order) as the true/false bodies - confirm these always agree.
            is_diamond = isinstance(node, CFG_DIAMOND)
            for i, child in enumerate(sorted(node.children, key=lambda n: n.id)):
                edge_label = ""
                if is_diamond and i < len(branch_labels):
                    edge_label = branch_labels[i]
                lines.append(f" n{node.id} -> n{child.id}{edge_label};")

        self.traverse(emit, start=self.START)
        lines.append("}")
        return "\n".join(lines)

    def traverse(self, fn: Callable[[CFG_Node], Any], start: CFG_Node | None = None) -> None:
        """Call *fn* once on every node reachable from *start*.

        Nodes are visited depth-first in pre-order, children in ascending
        ``id`` order, and each ``id`` is visited at most once. The walk uses
        an explicit stack so deep graphs cannot hit the recursion limit.

        Args:
            fn: Callback invoked with each visited node; its return value is
                ignored.
            start: Node to begin from; defaults to ``self.START``.
        """
        if start is None:
            start = self.START

        visited: Set[int] = set()
        pending = [start]
        while pending:
            node = pending.pop()
            if node.id in visited:
                continue
            visited.add(node.id)
            fn(node)
            # Push in reverse so the lowest-id child is visited next, which
            # reproduces the order of a recursive pre-order walk.
            pending.extend(reversed(sorted(node.children, key=lambda n: n.id)))