Source code for pyk.kast.pretty

  1from __future__ import annotations
  2
  3import logging
  4from collections.abc import Callable
  5from functools import cached_property
  6from typing import TYPE_CHECKING
  7
  8from ..prelude.kbool import TRUE
  9from .att import Atts, KAtt
 10from .inner import KApply, KAs, KInner, KLabel, KRewrite, KSequence, KSort, KToken, KVariable
 11from .manip import flatten_label, sort_ac_collections, undo_aliases
 12from .outer import (
 13    KBubble,
 14    KClaim,
 15    KContext,
 16    KDefinition,
 17    KFlatModule,
 18    KImport,
 19    KNonTerminal,
 20    KOuter,
 21    KProduction,
 22    KRegexTerminal,
 23    KRequire,
 24    KRule,
 25    KRuleLike,
 26    KSortSynonym,
 27    KSyntaxAssociativity,
 28    KSyntaxLexical,
 29    KSyntaxPriority,
 30    KSyntaxSort,
 31    KTerminal,
 32)
 33
 34if TYPE_CHECKING:
 35    from collections.abc import Iterable
 36    from typing import Any, Final, TypeVar
 37
 38    from .kast import KAst
 39
 40    RL = TypeVar('RL', bound='KRuleLike')
 41
 42_LOGGER: Final = logging.getLogger(__name__)
 43
 44SymbolTable = dict[str, Callable[..., str]]
 45
 46
[docs] 47class PrettyPrinter: 48 definition: KDefinition 49 _extra_unparsing_modules: Iterable[KFlatModule] 50 _patch_symbol_table: Callable[[SymbolTable], None] | None 51 _unalias: bool 52 _sort_collections: bool 53 54 def __init__( 55 self, 56 definition: KDefinition, 57 extra_unparsing_modules: Iterable[KFlatModule] = (), 58 patch_symbol_table: Callable[[SymbolTable], None] | None = None, 59 unalias: bool = True, 60 sort_collections: bool = False, 61 ): 62 self.definition = definition 63 self._extra_unparsing_modules = extra_unparsing_modules 64 self._patch_symbol_table = patch_symbol_table 65 self._unalias = unalias 66 self._sort_collections = sort_collections 67 68 @cached_property 69 def symbol_table(self) -> SymbolTable: 70 symb_table = build_symbol_table( 71 self.definition, 72 extra_modules=self._extra_unparsing_modules, 73 opinionated=True, 74 ) 75 if self._patch_symbol_table is not None: 76 self._patch_symbol_table(symb_table) 77 return symb_table 78
[docs] 79 def print(self, kast: KAst) -> str: 80 """Print out KAST terms/outer syntax. 81 - Input: KAST term. 82 - Output: Best-effort string representation of KAST term. 83 """ 84 _LOGGER.debug(f'Unparsing: {kast}') 85 if type(kast) is KAtt: 86 return self._print_katt(kast) 87 if type(kast) is KSort: 88 return self._print_ksort(kast) 89 if type(kast) is KLabel: 90 return self._print_klabel(kast) 91 elif isinstance(kast, KOuter): 92 return self._print_kouter(kast) 93 elif isinstance(kast, KInner): 94 if self._unalias: 95 kast = undo_aliases(self.definition, kast) 96 if self._sort_collections: 97 kast = sort_ac_collections(kast) 98 return self._print_kinner(kast) 99 raise AssertionError(f'Error unparsing: {kast}')
100 101 def _print_kouter(self, kast: KOuter) -> str: 102 match kast: 103 case KTerminal(): 104 return self._print_kterminal(kast) 105 case KRegexTerminal(): 106 return self._print_kregexterminal(kast) 107 case KNonTerminal(): 108 return self._print_knonterminal(kast) 109 case KProduction(): 110 return self._print_kproduction(kast) 111 case KSyntaxSort(): 112 return self._print_ksyntaxsort(kast) 113 case KSortSynonym(): 114 return self._print_ksortsynonym(kast) 115 case KSyntaxLexical(): 116 return self._print_ksyntaxlexical(kast) 117 case KSyntaxAssociativity(): 118 return self._print_ksyntaxassociativity(kast) 119 case KSyntaxPriority(): 120 return self._print_ksyntaxpriority(kast) 121 case KBubble(): 122 return self._print_kbubble(kast) 123 case KRule(): 124 return self._print_krule(kast) 125 case KClaim(): 126 return self._print_kclaim(kast) 127 case KContext(): 128 return self._print_kcontext(kast) 129 case KImport(): 130 return self._print_kimport(kast) 131 case KFlatModule(): 132 return self._print_kflatmodule(kast) 133 case KRequire(): 134 return self._print_krequire(kast) 135 case KDefinition(): 136 return self._print_kdefinition(kast) 137 case _: 138 raise AssertionError(f'Error unparsing: {kast}') 139 140 def _print_kinner(self, kast: KInner) -> str: 141 match kast: 142 case KVariable(): 143 return self._print_kvariable(kast) 144 case KToken(): 145 return self._print_ktoken(kast) 146 case KApply(): 147 return self._print_kapply(kast) 148 case KAs(): 149 return self._print_kas(kast) 150 case KRewrite(): 151 return self._print_krewrite(kast) 152 case KSequence(): 153 return self._print_ksequence(kast) 154 case _: 155 raise AssertionError(f'Error unparsing: {kast}') 156 157 def _print_ksort(self, ksort: KSort) -> str: 158 return ksort.name 159 160 def _print_klabel(self, klabel: KLabel) -> str: 161 return klabel.name 162 163 def _print_kvariable(self, kvariable: KVariable) -> str: 164 sort = kvariable.sort 165 if not sort: 166 return kvariable.name 167 
return kvariable.name + ':' + sort.name 168 169 def _print_ktoken(self, ktoken: KToken) -> str: 170 return ktoken.token 171 172 def _print_kapply(self, kapply: KApply) -> str: 173 label = kapply.label.name 174 args = kapply.args 175 unparsed_args = [self._print_kinner(arg) for arg in args] 176 if kapply.is_cell: 177 cell_contents = '\n'.join(unparsed_args).rstrip() 178 cell_str = label + '\n' + indent(cell_contents) + '\n</' + label[1:] 179 return cell_str.rstrip() 180 unparser = self._applied_label_str(label) if label not in self.symbol_table else self.symbol_table[label] 181 return unparser(*unparsed_args) 182 183 def _print_kas(self, kas: KAs) -> str: 184 pattern_str = self._print_kinner(kas.pattern) 185 alias_str = self._print_kinner(kas.alias) 186 return pattern_str + ' #as ' + alias_str 187 188 def _print_krewrite(self, krewrite: KRewrite) -> str: 189 lhs_str = self._print_kinner(krewrite.lhs) 190 rhs_str = self._print_kinner(krewrite.rhs) 191 return '( ' + lhs_str + ' => ' + rhs_str + ' )' 192 193 def _print_ksequence(self, ksequence: KSequence) -> str: 194 if ksequence.arity == 0: 195 # TODO: Would be nice to say `return self._print_kinner(EMPTY_K)` 196 return '.K' 197 if ksequence.arity == 1: 198 return self._print_kinner(ksequence.items[0]) + ' ~> .K' 199 unparsed_k_seq = '\n~> '.join([self._print_kinner(item) for item in ksequence.items[0:-1]]) 200 if ksequence.items[-1] == KToken('...', KSort('K')): 201 unparsed_k_seq = unparsed_k_seq + '\n' + self._print_kinner(KToken('...', KSort('K'))) 202 else: 203 unparsed_k_seq = unparsed_k_seq + '\n~> ' + self._print_kinner(ksequence.items[-1]) 204 return unparsed_k_seq 205 206 def _print_kterminal(self, kterminal: KTerminal) -> str: 207 return '"' + kterminal.value + '"' 208 209 def _print_kregexterminal(self, kregexterminal: KRegexTerminal) -> str: 210 return 'r"' + kregexterminal.regex + '"' 211 212 def _print_knonterminal(self, knonterminal: KNonTerminal) -> str: 213 return self.print(knonterminal.sort) 214 
215 def _print_kproduction(self, kproduction: KProduction) -> str: 216 if Atts.KLABEL not in kproduction.att and kproduction.klabel: 217 kproduction = kproduction.update_atts([Atts.KLABEL(kproduction.klabel.name)]) 218 syntax_str = 'syntax ' + self.print(kproduction.sort) 219 if kproduction.items: 220 syntax_str += ' ::= ' + ' '.join([self._print_kouter(pi) for pi in kproduction.items]) 221 att_str = self.print(kproduction.att) 222 if att_str: 223 syntax_str += ' ' + att_str 224 return syntax_str 225 226 def _print_ksyntaxsort(self, ksyntaxsort: KSyntaxSort) -> str: 227 sort_str = self.print(ksyntaxsort.sort) 228 att_str = self.print(ksyntaxsort.att) 229 return 'syntax ' + sort_str + ' ' + att_str 230 231 def _print_ksortsynonym(self, ksortsynonym: KSortSynonym) -> str: 232 new_sort_str = self.print(ksortsynonym.new_sort) 233 old_sort_str = self.print(ksortsynonym.old_sort) 234 att_str = self.print(ksortsynonym.att) 235 return 'syntax ' + new_sort_str + ' = ' + old_sort_str + ' ' + att_str 236 237 def _print_ksyntaxlexical(self, ksyntaxlexical: KSyntaxLexical) -> str: 238 name_str = ksyntaxlexical.name 239 regex_str = ksyntaxlexical.regex 240 att_str = self.print(ksyntaxlexical.att) 241 # todo: proper escaping 242 return 'syntax lexical ' + name_str + ' = r"' + regex_str + '" ' + att_str 243 244 def _print_ksyntaxassociativity(self, ksyntaxassociativity: KSyntaxAssociativity) -> str: 245 assoc_str = ksyntaxassociativity.assoc.value 246 tags_str = ' '.join(ksyntaxassociativity.tags) 247 att_str = self.print(ksyntaxassociativity.att) 248 return 'syntax associativity ' + assoc_str + ' ' + tags_str + ' ' + att_str 249 250 def _print_ksyntaxpriority(self, ksyntaxpriority: KSyntaxPriority) -> str: 251 priorities_str = ' > '.join([' '.join(group) for group in ksyntaxpriority.priorities]) 252 att_str = self.print(ksyntaxpriority.att) 253 return 'syntax priority ' + priorities_str + ' ' + att_str 254 255 def _print_kbubble(self, kbubble: KBubble) -> str: 256 body = '// 
KBubble(' + kbubble.sentence_type + ', ' + kbubble.content + ')' 257 att_str = self.print(kbubble.att) 258 return body + ' ' + att_str 259 260 def _print_krule(self, kterm: KRule) -> str: 261 body = '\n '.join(self.print(kterm.body).split('\n')) 262 rule_str = 'rule ' 263 if Atts.LABEL in kterm.att: 264 rule_str = rule_str + '[' + kterm.att[Atts.LABEL] + ']:' 265 rule_str = rule_str + ' ' + body 266 atts_str = self.print(kterm.att) 267 if kterm.requires != TRUE: 268 requires_str = 'requires ' + '\n '.join(self._print_kast_bool(kterm.requires).split('\n')) 269 rule_str = rule_str + '\n ' + requires_str 270 if kterm.ensures != TRUE: 271 ensures_str = 'ensures ' + '\n '.join(self._print_kast_bool(kterm.ensures).split('\n')) 272 rule_str = rule_str + '\n ' + ensures_str 273 return rule_str + '\n ' + atts_str 274 275 def _print_kclaim(self, kterm: KClaim) -> str: 276 body = '\n '.join(self.print(kterm.body).split('\n')) 277 rule_str = 'claim ' 278 if Atts.LABEL in kterm.att: 279 rule_str = rule_str + '[' + kterm.att[Atts.LABEL] + ']:' 280 rule_str = rule_str + ' ' + body 281 atts_str = self.print(kterm.att) 282 if kterm.requires != TRUE: 283 requires_str = 'requires ' + '\n '.join(self._print_kast_bool(kterm.requires).split('\n')) 284 rule_str = rule_str + '\n ' + requires_str 285 if kterm.ensures != TRUE: 286 ensures_str = 'ensures ' + '\n '.join(self._print_kast_bool(kterm.ensures).split('\n')) 287 rule_str = rule_str + '\n ' + ensures_str 288 return rule_str + '\n ' + atts_str 289 290 def _print_kcontext(self, kcontext: KContext) -> str: 291 body = indent(self.print(kcontext.body)) 292 context_str = 'context alias ' + body 293 requires_str = '' 294 atts_str = self.print(kcontext.att) 295 if kcontext.requires != TRUE: 296 requires_str = self.print(kcontext.requires) 297 requires_str = 'requires ' + indent(requires_str) 298 return context_str + '\n ' + requires_str + '\n ' + atts_str 299 300 def _print_katt(self, katt: KAtt) -> str: 301 return katt.pretty 302 303 def 
_print_kimport(self, kimport: KImport) -> str: 304 return ' '.join(['imports', ('public' if kimport.public else 'private'), kimport.name]) 305 306 def _print_kflatmodule(self, kflatmodule: KFlatModule) -> str: 307 name = kflatmodule.name 308 imports = '\n'.join([self._print_kouter(kimport) for kimport in kflatmodule.imports]) 309 sentences = '\n\n'.join([self._print_kouter(sentence) for sentence in kflatmodule.sentences]) 310 contents = imports + '\n\n' + sentences 311 return 'module ' + name + '\n ' + '\n '.join(contents.split('\n')) + '\n\nendmodule' 312 313 def _print_krequire(self, krequire: KRequire) -> str: 314 return 'requires "' + krequire.require + '"' 315 316 def _print_kdefinition(self, kdefinition: KDefinition) -> str: 317 requires = '\n'.join([self._print_kouter(require) for require in kdefinition.requires]) 318 modules = '\n\n'.join([self._print_kouter(module) for module in kdefinition.all_modules]) 319 return requires + '\n\n' + modules 320 321 def _print_kast_bool(self, kast: KAst) -> str: 322 """Print out KAST requires/ensures clause. 323 324 - Input: KAST Bool for requires/ensures clause. 325 - Output: Best-effort string representation of KAST term. 
326 """ 327 _LOGGER.debug(f'_print_kast_bool: {kast}') 328 if type(kast) is KApply and kast.label.name in ['_andBool_', '_orBool_']: 329 clauses = [self._print_kast_bool(c) for c in flatten_label(kast.label.name, kast)] 330 head = kast.label.name.replace('_', ' ') 331 if head == ' orBool ': 332 head = ' orBool ' 333 separator = ' ' * (len(head) - 7) 334 spacer = ' ' * len(head) 335 336 def join_sep(s: str) -> str: 337 return ('\n' + separator).join(s.split('\n')) 338 339 clauses = ( 340 ['( ' + join_sep(clauses[0])] 341 + [head + '( ' + join_sep(c) for c in clauses[1:]] 342 + [spacer + (')' * len(clauses))] 343 ) 344 return '\n'.join(clauses) 345 else: 346 return self.print(kast) 347 348 def _applied_label_str(self, symbol: str) -> Callable[..., str]: 349 return lambda *args: symbol + ' ( ' + ' , '.join(args) + ' )'
350 351
def build_symbol_table(
    definition: KDefinition,
    extra_modules: Iterable[KFlatModule] = (),
    opinionated: bool = False,
) -> SymbolTable:
    """Build the unparsing symbol table for a K definition.

    Every syntax production of ``definition.all_modules`` (followed by those of
    ``extra_modules``) contributes an entry keyed by its klabel name, mapping to the
    unparser produced by :func:`unparser_for_production`. Productions carrying both the
    ``symbol`` and ``klabel`` attributes are additionally registered under the value of
    their ``klabel`` attribute. Modules are processed in order, so later entries overwrite
    earlier ones with the same key.

    Args:
        definition: The K definition whose productions are turned into unparsers.
        extra_modules: Extra modules processed after the definition's own; ``None`` is tolerated.
        opinionated: If set, install multi-line unparsers for ``#And`` and ``#Or``.

    Returns:
        Dictionary mapping label names to functions that render already-printed arguments.
    """
    symbol_table: SymbolTable = {}
    all_modules = list(definition.all_modules) + ([] if extra_modules is None else list(extra_modules))
    for module in all_modules:
        for prod in module.syntax_productions:
            # Every syntax production is expected to have a klabel at this point.
            assert prod.klabel
            label = prod.klabel.name
            unparser = unparser_for_production(prod)

            symbol_table[label] = unparser
            if Atts.SYMBOL in prod.att and Atts.KLABEL in prod.att:
                symbol_table[prod.att[Atts.KLABEL]] = unparser

    if opinionated:
        # Put each conjunct/disjunct on its own line; the right operand of #Or is indented.
        symbol_table['#And'] = lambda c1, c2: c1 + '\n#And ' + c2
        symbol_table['#Or'] = lambda c1, c2: c1 + '\n#Or\n' + indent(c2, size=4)

    return symbol_table
379 380
def unparser_for_production(prod: KProduction) -> Callable[..., str]:
    """Create an unparser for a single syntax production.

    The returned function takes the already-printed arguments (strings) in order and
    interleaves them with the production's terminal values, joining all pieces with single
    spaces. Non-terminals beyond the number of supplied arguments are skipped. When every
    non-terminal of the production is named, each argument is prefixed with ``name:``
    and the first one is additionally preceded by ``...``.
    """
    # Loop-invariant per production: decided once here instead of on every application.
    nonterminals = [item for item in prod.items if type(item) is KNonTerminal]
    all_named = all(item.name is not None for item in nonterminals)

    def _unparser(*args: Any) -> str:
        result: list[str] = []
        index = 0
        for item in prod.items:
            if type(item) is KTerminal:
                result.append(item.value)
            elif type(item) is KNonTerminal and index < len(args):
                if all_named:
                    if index == 0:
                        result.append('...')
                    result.append(f'{item.name}:')
                result.append(args[index])
                index += 1
        return ' '.join(result)

    return _unparser
400 401
def indent(text: str, size: int = 2) -> str:
    """Prefix every line of ``text`` (split on ``'\\n'``) with ``size`` spaces.

    Empty lines are padded as well, so a trailing newline yields a final padded line.
    """
    padding = ' ' * size
    return '\n'.join(padding + line for line in text.split('\n'))
404 405
def paren(printer: Callable[..., str]) -> Callable[..., str]:
    """Wrap ``printer`` so that its output is surrounded by ``'( '`` and ``' )'``."""

    def _parenthesized(*args: Any) -> str:
        inner = printer(*args)
        return '( ' + inner + ' )'

    return _parenthesized
408 409
def assoc_with_unit(assoc_join: str, unit: str) -> Callable[..., str]:
    """Build a printer for an associative operator that drops unit elements.

    Arguments equal to ``unit`` are omitted and the rest are joined with ``assoc_join``;
    if every argument is a unit (or there are none), the result is the empty string.
    """

    def _join_non_units(*args: str) -> str:
        kept = [arg for arg in args if arg != unit]
        return assoc_join.join(kept)

    return _join_non_units