Files
Construction-of-Compilers/Project-02-03-04/cfg/CFG.py
2026-01-22 20:26:41 +01:00

197 lines
7.7 KiB
Python

from typing import Any
from .CFG_Node import *
class CFG:
    """Control-flow graph delimited by an entry node and an exit node.

    When an AST is passed to the constructor, empty nodes are spliced out of
    the graph once, at construction time; ``to_dot`` only formats the result.
    """

    def __init__(self, in_node: "CFG_Node", out_node: "CFG_Node", ast=None):
        """Store the graph endpoints and optionally filter the graph.

        Args:
            in_node: Node from which ``to_dot`` starts its traversal.
            out_node: Stored as ``self.out_node``; not read by this class.
            ast: Optional AST, stored as ``self.ast``. When it is not ``None``
                the graph is filtered via ``_filter_graph``.
        """
        # Annotations are quoted forward references so they are not evaluated
        # when the class is defined.
        self.in_node = in_node
        self.out_node = out_node
        self.ast = ast
        # If an AST is provided, filter the graph by removing empty nodes.
        if ast is not None:
            self._filter_graph()

    def _filter_graph(self):
        """Remove empty nodes from the graph and rewire edges around them.

        This runs once during construction, never from ``to_dot()``. The set
        of nodes to remove is decided before any removal happens. Removals are
        applied in ascending ``id`` order, so the result does not depend on
        set iteration order.
        """
        all_nodes = set()
        self._collect_nodes(self.in_node, all_nodes)
        nodes_to_remove = sorted(
            (node for node in all_nodes if self._should_remove_node(node)),
            key=lambda n: n.id,
        )
        for node in nodes_to_remove:
            self._remove_node_and_rewire(node)

    def _collect_nodes(self, node, node_set):
        """Add *node* and every node reachable via ``children`` to *node_set*.

        An explicit stack is used instead of recursion, so long chains of
        nodes cannot exceed Python's recursion limit. Nodes already present in
        *node_set* are not expanded again.
        """
        stack = [node]
        while stack:
            current = stack.pop()
            if current in node_set:
                continue
            node_set.add(current)
            stack.extend(current.children)

    def _should_remove_node(self, node):
        """Return True if *node* carries no content worth drawing.

        A node is removed when either:

        * its ``label`` is ``None`` or the string ``"None"``, it has no
          ``ast_node``, and its ``dot_label()`` is not ``"START"``; or
        * its ``dot_label()`` is ``"END"`` and its ``label`` does not contain
          both ``(`` and ``)``. Labels containing both mark function-specific
          END nodes, which are kept.
        """
        has_label = hasattr(node, "label")
        if has_label and (node.label is None or node.label == "None"):
            # Nodes with an AST node are kept: to_dot() labels them from the AST.
            if node.ast_node is not None:
                return False
            # Global START nodes have label=None but must still be drawn.
            if hasattr(node, "dot_label") and node.dot_label() == "START":
                return False
            return True
        # Drop global END nodes and keep function-specific ones.
        if hasattr(node, "dot_label") and node.dot_label() == "END":
            is_function_end = (
                has_label and node.label and "(" in node.label and ")" in node.label
            )
            return not is_function_end
        return False

    def _rewire_targets(self, node, children):
        """Return the kept nodes that replace *node* as a child of its parents.

        For a diamond (conditional) node, only the first and second child are
        followed, in that order. This matches the T/F branch positions. For
        any other node, every child is followed. Each child is resolved with
        ``_find_first_non_empty_child``. Unresolvable children and duplicates
        are skipped.
        """
        if hasattr(node, "dot_shape") and node.dot_shape() == "diamond":
            candidates = children[:2]
        else:
            candidates = children
        targets = []
        for child in candidates:
            target = self._find_first_non_empty_child(child)
            if target is not None and target not in targets:
                targets.append(target)
        return targets

    def _remove_node_and_rewire(self, node):
        """Splice *node* out of the graph, connecting its parents to its targets.

        Every parent that still lists *node* as a child loses that edge and
        gains edges to the targets from ``_rewire_targets``. Afterwards the
        ``parents`` and ``children`` collections of *node* are cleared.
        """
        # Snapshot the children before any edge is modified.
        original_children = list(node.children)
        # In an acyclic graph the targets are the same for every parent, so
        # they are computed once.
        targets = self._rewire_targets(node, original_children)
        for parent in list(node.parents):
            if node not in parent.children:
                continue
            # NOTE(review): propagate=False presumably edits only the parent
            # side of the edge; confirm against CFG_Node.
            parent.remove_child(node, propagate=False)
            for target in targets:
                parent.add_child(target, propagate=False)
        node.parents.clear()
        node.children.clear()

    def _find_first_non_empty_child(self, node, _seen=None):
        """Return *node* if it is kept, else its first kept descendant, else None.

        Removable nodes are searched depth-first, with children visited in
        ascending ``id`` order. *_seen* records removable nodes already
        expanded during this search. A cycle made only of removable nodes
        therefore ends (that path yields None) instead of recursing forever.
        """
        if not self._should_remove_node(node):
            return node
        if _seen is None:
            _seen = set()
        if node in _seen:
            return None
        _seen.add(node)
        for child in sorted(node.children, key=lambda n: n.id):
            result = self._find_first_non_empty_child(child, _seen)
            if result is not None:
                return result
        return None

    def to_dot(self) -> str:
        """Render the already-filtered CFG as a Graphviz DOT string.

        This method only formats; graph simplification happens in the
        constructor. Traversal is a depth-first pre-order walk from
        ``self.in_node``:

        * children are visited in ascending ``id`` order;
        * each edge line is written just before its target node is visited;
        * a node whose label resolves to None is not declared and its
          children are not traversed.

        Double quotes inside labels are escaped so the output stays valid DOT.
        """
        visited = set()
        lines = ["digraph CFG {"]
        lines.append(' node [fontname="Helvetica"];')

        def node_label(node: "CFG_Node") -> "str | None | Any":
            # A non-empty custom label always wins.
            if hasattr(node, "label") and node.label:
                return node.label
            base = node.dot_label() if hasattr(node, "dot_label") else ""
            # Otherwise use the AST text, prefixed by the base label if one exists.
            if node.ast_node is not None:
                semantic = str(node.ast_node)
                return f"{base}\n{semantic}" if base else semantic
            return base if base else None

        def escape(text) -> str:
            # Inside a DOT quoted string an unescaped '"' ends the string.
            return str(text).replace('"', '\\"')

        def node_shape(node: "CFG_Node") -> str:
            return node.dot_shape() if hasattr(node, "dot_shape") else "box"

        def node_style(node: "CFG_Node") -> str:
            # CALL/RET nodes are filled orange; START/END nodes are filled green.
            label = node.label if hasattr(node, "label") else None
            if not label:
                return ""
            if label.startswith(("CALL", "RET")):
                return "style=filled, color=orange"
            if label.startswith(("START", "END")):
                return "style=filled, color=green"
            return ""

        def edge_line(node: "CFG_Node", index: int, child: "CFG_Node") -> str:
            edge_label = ""
            # Diamond (conditional) nodes label the first two edges T and F.
            # NOTE(review): T/F follow child id order. Rewiring in
            # _remove_node_and_rewire can replace a branch child with a node
            # whose id sorts differently; confirm branches stay correct.
            if hasattr(node, "dot_shape") and node.dot_shape() == "diamond":
                if index == 0:
                    edge_label = ' [label="T"]'
                elif index == 1:
                    edge_label = ' [label="F"]'
            return f" n{node.id} -> n{child.id}{edge_label};"

        def open_node(node: "CFG_Node"):
            """Declare *node*; return an iterator over its children, or None."""
            if node.id in visited:
                return None
            visited.add(node.id)
            label = node_label(node)
            if label is None:
                # Should not happen if the constructor filtered the graph.
                return None
            shape = node_shape(node)
            style = node_style(node)
            style_str = f", {style}" if style else ""
            lines.append(
                f' n{node.id} [label="{escape(label)}", shape={shape}{style_str}];'
            )
            return enumerate(sorted(node.children, key=lambda n: n.id))

        # An explicit stack replaces recursion, so long CFGs cannot hit the
        # recursion limit. The output order is the same as recursive pre-order.
        stack = []
        root_children = open_node(self.in_node)
        if root_children is not None:
            stack.append((self.in_node, root_children))
        while stack:
            parent, pending = stack[-1]
            step = next(pending, None)
            if step is None:
                stack.pop()
                continue
            index, child = step
            lines.append(edge_line(parent, index, child))
            child_children = open_node(child)
            if child_children is not None:
                stack.append((child, child_children))

        lines.append("}")
        return "\n".join(lines)