Files
Construction-of-Compilers/Project-02-03-04-05/cfa/ReachedUses.py
2026-03-08 16:33:07 +01:00

97 lines
3.7 KiB
Python

from __future__ import annotations
from typing import TYPE_CHECKING
from cfa.BackwardAnalysis import BackwardAnalysis, Var
if TYPE_CHECKING:
from cfg.CFG import CFG
# A single use-fact: the id of the CFG node at which a variable is used,
# paired with that variable.
# e.g. (42, ("f", "x")) -> variable "x" in function "f" is used at node 42
# NOTE(review): Var comes from cfa.BackwardAnalysis; the example above suggests
# it is a (function name, variable name) pair -- confirm against that module.
UseFact = tuple[int, Var]
class ReachedUses(BackwardAnalysis):
    """Backward data-flow analysis that computes reached uses per CFG node.

    A fact is a ``UseFact`` ``(use_node_id, var)``.  The equations solved are:

        OUT(n) = union of IN(s) over all successors s of n
        IN(n)  = GEN(n) | (OUT(n) - KILL(n))

    where GEN(n) holds the uses occurring at n and KILL(n) holds every
    use-fact of each variable that n defines.

    The analysis runs eagerly: constructing an instance builds all sets and
    iterates to the fixpoint, after which ``reached_uses_by_node`` can be
    queried.

    NOTE(review): relies on the base class providing ``self.cfg``,
    ``self.uses`` and ``self.defs`` (the latter two indexed by node id) --
    confirm against ``BackwardAnalysis``.
    """

    def __init__(self, cfg: "CFG") -> None:
        """Set up the per-node sets for *cfg* and solve the analysis."""
        # The base initialiser is expected to populate: uses, defs,
        # _func_scope, _func_parent, _func_params.
        super().__init__(cfg)

        # Per-node data-flow sets, all keyed by CFG node id.
        self.gen: dict[int, set[UseFact]] = {}
        self.kill: dict[int, set[UseFact]] = {}
        self.in_sets: dict[int, set[UseFact]] = {}
        self.out_sets: dict[int, set[UseFact]] = {}
        # Reverse index: variable -> every use-fact of that variable.
        self.all_uses_by_var: dict[Var, set[UseFact]] = {}

        self.__init_sets()
        self.solve()

    def __init_sets(self) -> None:
        """Initialise the gen, kill, in and out sets for every CFG node."""
        # Pass 1 -- GEN(n) = { (n.id, var) | var in USE(n) },
        #           IN(n) starts as a copy of GEN(n), OUT(n) starts empty.
        for cfg_node in self.cfg.nodes():
            node_id = cfg_node.id
            generated: set[UseFact] = set()
            for used_var in self.uses[node_id]:
                generated.add((node_id, used_var))
            self.gen[node_id] = generated
            self.in_sets[node_id] = generated.copy()
            self.out_sets[node_id] = set()

        # Pass 2 -- KILL(n) needs "at which nodes is variable x used anywhere?".
        # all_uses_by_var answers that with a lookup built once upfront:
        #   ("f", "x") -> { (42, ("f", "x")), (17, ("f", "x")) }
        for generated in self.gen.values():
            for fact in generated:
                self.all_uses_by_var.setdefault(fact[1], set()).add(fact)

        # Pass 3 -- KILL(n) = { fact | var in DEF(n), fact in all_uses_by_var[var] }.
        # A definition at n kills all use-facts of the defined variable: uses
        # that lie beyond n are reached by n's definition, not by earlier ones.
        for cfg_node in self.cfg.nodes():
            node_id = cfg_node.id
            self.kill[node_id] = set().union(
                *(self.all_uses_by_var.get(defined_var, ())
                  for defined_var in self.defs[node_id])
            )

    def solve(self) -> None:
        """Re-evaluate the data-flow equations until no set changes."""
        ordered_nodes = list(self.cfg.nodes())
        valid_ids: set[int] = {n.id for n in ordered_nodes}

        stable = False
        while not stable:
            stable = True
            for cfg_node in ordered_nodes:
                node_id = cfg_node.id

                # OUT(n) = union of IN(s) over successors s; successors whose
                # id is not among the CFG's nodes are ignored.
                successor_ins = [
                    self.in_sets[succ.id]
                    for succ in cfg_node.children
                    if succ.id in valid_ids
                ]
                out_n: set[UseFact] = set().union(*successor_ins)

                # IN(n) = GEN(n) | (OUT(n) - KILL(n))
                in_n: set[UseFact] = (out_n - self.kill[node_id]) | self.gen[node_id]

                if out_n == self.out_sets[node_id] and in_n == self.in_sets[node_id]:
                    continue  # nothing changed for this node

                self.out_sets[node_id] = out_n
                self.in_sets[node_id] = in_n
                stable = False  # something changed -> sweep again

    def reached_uses_by_node(self) -> dict[int, list[int]]:
        """Return the final reached-uses result.

        Maps the id of every node that defines at least one variable to the
        sorted list of node ids whose use-facts appear in that node's OUT set
        for a variable the node defines.  Nodes without definitions are
        omitted from the result.
        """
        reached_by_def: dict[int, list[int]] = {}
        for cfg_node in self.cfg.nodes():
            node_id = cfg_node.id
            defined_vars = self.defs[node_id]
            if defined_vars:
                reached_by_def[node_id] = sorted(
                    {
                        use_id
                        for use_id, var in self.out_sets[node_id]
                        if var in defined_vars
                    }
                )
        return reached_by_def