97 lines
3.7 KiB
Python
97 lines
3.7 KiB
Python
from __future__ import annotations
|
|
from typing import TYPE_CHECKING
|
|
from cfa.BackwardAnalysis import BackwardAnalysis, Var
|
|
|
|
if TYPE_CHECKING:
|
|
from cfg.CFG import CFG
|
|
|
|
# A single use-fact: the id of the CFG node at which a variable is used,
# paired with the variable being used.
# e.g. (42, ("f", "x")) -> variable "x" in function "f" is used at node 42
UseFact = tuple[int, Var]
|
|
|
|
class ReachedUses(BackwardAnalysis):
    """Reached-uses analysis: which uses of a variable each definition reaches.

    A backward, union-based ("may") dataflow analysis whose facts are
    ``UseFact`` pairs ``(use_node_id, var)``. For every CFG node ``n``::

        GEN(n)  = { (n, v) | v in USE(n) }
        KILL(n) = { (u, v) | v in DEF(n), (u, v) is any use-fact of v }
        OUT(n)  = UNION of IN(s) over the successors s of n
        IN(n)   = GEN(n) UNION (OUT(n) MINUS KILL(n))

    After construction, ``out_sets[n]`` holds every use-fact reachable from
    the successors of ``n`` without passing through an intermediate node that
    redefines the used variable. ``reached_uses_by_node`` projects that onto
    the variables each node defines.
    """

    def __init__(self, cfg: "CFG") -> None:
        # Base populates: uses, defs, _func_scope, _func_parent, _func_params.
        super().__init__(cfg)

        self.gen: dict[int, set[UseFact]] = {}
        self.kill: dict[int, set[UseFact]] = {}
        self.in_sets: dict[int, set[UseFact]] = {}
        self.out_sets: dict[int, set[UseFact]] = {}
        # var -> every use-fact of that var anywhere in the CFG (input to KILL).
        self.all_uses_by_var: dict[Var, set[UseFact]] = {}

        self.__init_sets()
        self.solve()

    def __init_sets(self) -> None:
        """Initialize GEN, KILL, IN and OUT for every CFG node.

        IN starts as a copy of GEN and OUT starts empty; ``solve`` then grows
        both to the fixpoint.
        """
        for node in self.cfg.nodes():
            nid = node.id
            # GEN(n) = { (n.id, var) | var IN USE(n) }
            self.gen[nid] = {(nid, var) for var in self.uses[nid]}
            # IN(n) = GEN(n); OUT(n) = empty
            self.in_sets[nid] = set(self.gen[nid])
            self.out_sets[nid] = set()

        # KILL(n) needs every use-fact of a variable ("at which nodes is x used
        # anywhere?"), so build that index once up front, e.g.
        # ("f", "x") -> { (42, ("f", "x")), (17, ("f", "x")) }
        for facts in self.gen.values():
            for fact in facts:
                _, var = fact
                self.all_uses_by_var.setdefault(var, set()).add(fact)

        for node in self.cfg.nodes():
            nid = node.id
            # KILL(n) = { (uid, var) | var IN DEF(n), (uid, var) IN all_uses_by_var[var] }
            # Because n redefines var, no use of var downstream of n can be
            # reached *through n* by an earlier definition of var, so every
            # use-fact of var is removed when flowing backward past n. A use of
            # var at n itself survives: solve() adds GEN(n) back after
            # subtracting KILL(n), i.e. uses at n count as happening before
            # n's definitions.
            kill_n: set[UseFact] = set()
            for var in self.defs[nid]:
                kill_n |= self.all_uses_by_var.get(var, set())
            self.kill[nid] = kill_n

    def solve(self) -> None:
        """Iterate the dataflow equations until no IN/OUT set changes.

        Round-robin iteration: every node is re-evaluated on each pass until a
        full pass changes nothing. The equations are monotone over a finite
        set of use-facts, so the loop terminates. Successors whose ids are not
        among ``self.cfg.nodes()`` have no IN set and are ignored.
        """
        nodes = list(self.cfg.nodes())
        known: set[int] = {n.id for n in nodes}

        # while there are changes do
        changes = True
        while changes:
            changes = False

            # for all v in V do
            for node in nodes:
                nid = node.id

                # OUT(n) = UNION IN(s) for all successors s
                new_out: set[UseFact] = set()
                for child in node.children:
                    if child.id in known:
                        new_out |= self.in_sets[child.id]

                # IN(n) = GEN(n) UNION (OUT(n) MINUS KILL(n))
                new_in: set[UseFact] = self.gen[nid] | (new_out - self.kill[nid])

                if new_out != self.out_sets[nid] or new_in != self.in_sets[nid]:
                    self.out_sets[nid] = new_out
                    self.in_sets[nid] = new_in
                    changes = True  # there are changes -> loop again

    def reached_uses_by_node(self) -> dict[int, list[int]]:
        """Return the final reached-uses result.

        Returns:
            A dict mapping the id of every node with a non-empty DEF set to the
            sorted, de-duplicated ids of the use nodes it reaches: ``u`` is
            listed for ``n`` when ``(u, v)`` is in OUT(n) for some ``v``
            defined at ``n``. The list may be empty. Nodes that define nothing
            are omitted.
        """
        result: dict[int, list[int]] = {}
        for node in self.cfg.nodes():
            nid = node.id
            defs_n = self.defs[nid]
            if not defs_n:
                continue
            reached = {uid for uid, var in self.out_sets[nid] if var in defs_n}
            result[nid] = sorted(reached)
        return result