1from __future__ import annotations
2
3import logging
4from collections.abc import Callable
5from functools import cached_property
6from typing import TYPE_CHECKING
7
8from ..prelude.kbool import TRUE
9from .att import Atts, KAtt
10from .inner import KApply, KAs, KInner, KLabel, KRewrite, KSequence, KSort, KToken, KVariable
11from .manip import flatten_label, sort_ac_collections, undo_aliases
12from .outer import (
13 KBubble,
14 KClaim,
15 KContext,
16 KDefinition,
17 KFlatModule,
18 KImport,
19 KNonTerminal,
20 KOuter,
21 KProduction,
22 KRegexTerminal,
23 KRequire,
24 KRule,
25 KRuleLike,
26 KSortSynonym,
27 KSyntaxAssociativity,
28 KSyntaxLexical,
29 KSyntaxPriority,
30 KSyntaxSort,
31 KTerminal,
32)
33
34if TYPE_CHECKING:
35 from collections.abc import Iterable
36 from typing import Any, Final, TypeVar
37
38 from .kast import KAst
39
40 RL = TypeVar('RL', bound='KRuleLike')
41
# Module-level logger, named after this module.
_LOGGER: Final = logging.getLogger(__name__)

# Maps a label name to a callable that receives the already-printed arguments
# (as strings) and returns the printed term.
SymbolTable = dict[str, Callable[..., str]]
45
46
[docs]
47class PrettyPrinter:
48 definition: KDefinition
49 _extra_unparsing_modules: Iterable[KFlatModule]
50 _patch_symbol_table: Callable[[SymbolTable], None] | None
51 _unalias: bool
52 _sort_collections: bool
53
54 def __init__(
55 self,
56 definition: KDefinition,
57 extra_unparsing_modules: Iterable[KFlatModule] = (),
58 patch_symbol_table: Callable[[SymbolTable], None] | None = None,
59 unalias: bool = True,
60 sort_collections: bool = False,
61 ):
62 self.definition = definition
63 self._extra_unparsing_modules = extra_unparsing_modules
64 self._patch_symbol_table = patch_symbol_table
65 self._unalias = unalias
66 self._sort_collections = sort_collections
67
68 @cached_property
69 def symbol_table(self) -> SymbolTable:
70 symb_table = build_symbol_table(
71 self.definition,
72 extra_modules=self._extra_unparsing_modules,
73 opinionated=True,
74 )
75 if self._patch_symbol_table is not None:
76 self._patch_symbol_table(symb_table)
77 return symb_table
78
[docs]
79 def print(self, kast: KAst) -> str:
80 """Print out KAST terms/outer syntax.
81 - Input: KAST term.
82 - Output: Best-effort string representation of KAST term.
83 """
84 _LOGGER.debug(f'Unparsing: {kast}')
85 if type(kast) is KAtt:
86 return self._print_katt(kast)
87 if type(kast) is KSort:
88 return self._print_ksort(kast)
89 if type(kast) is KLabel:
90 return self._print_klabel(kast)
91 elif isinstance(kast, KOuter):
92 return self._print_kouter(kast)
93 elif isinstance(kast, KInner):
94 if self._unalias:
95 kast = undo_aliases(self.definition, kast)
96 if self._sort_collections:
97 kast = sort_ac_collections(kast)
98 return self._print_kinner(kast)
99 raise AssertionError(f'Error unparsing: {kast}')
100
101 def _print_kouter(self, kast: KOuter) -> str:
102 match kast:
103 case KTerminal():
104 return self._print_kterminal(kast)
105 case KRegexTerminal():
106 return self._print_kregexterminal(kast)
107 case KNonTerminal():
108 return self._print_knonterminal(kast)
109 case KProduction():
110 return self._print_kproduction(kast)
111 case KSyntaxSort():
112 return self._print_ksyntaxsort(kast)
113 case KSortSynonym():
114 return self._print_ksortsynonym(kast)
115 case KSyntaxLexical():
116 return self._print_ksyntaxlexical(kast)
117 case KSyntaxAssociativity():
118 return self._print_ksyntaxassociativity(kast)
119 case KSyntaxPriority():
120 return self._print_ksyntaxpriority(kast)
121 case KBubble():
122 return self._print_kbubble(kast)
123 case KRule():
124 return self._print_krule(kast)
125 case KClaim():
126 return self._print_kclaim(kast)
127 case KContext():
128 return self._print_kcontext(kast)
129 case KImport():
130 return self._print_kimport(kast)
131 case KFlatModule():
132 return self._print_kflatmodule(kast)
133 case KRequire():
134 return self._print_krequire(kast)
135 case KDefinition():
136 return self._print_kdefinition(kast)
137 case _:
138 raise AssertionError(f'Error unparsing: {kast}')
139
140 def _print_kinner(self, kast: KInner) -> str:
141 match kast:
142 case KVariable():
143 return self._print_kvariable(kast)
144 case KToken():
145 return self._print_ktoken(kast)
146 case KApply():
147 return self._print_kapply(kast)
148 case KAs():
149 return self._print_kas(kast)
150 case KRewrite():
151 return self._print_krewrite(kast)
152 case KSequence():
153 return self._print_ksequence(kast)
154 case _:
155 raise AssertionError(f'Error unparsing: {kast}')
156
157 def _print_ksort(self, ksort: KSort) -> str:
158 return ksort.name
159
160 def _print_klabel(self, klabel: KLabel) -> str:
161 return klabel.name
162
163 def _print_kvariable(self, kvariable: KVariable) -> str:
164 sort = kvariable.sort
165 if not sort:
166 return kvariable.name
167 return kvariable.name + ':' + sort.name
168
169 def _print_ktoken(self, ktoken: KToken) -> str:
170 return ktoken.token
171
172 def _print_kapply(self, kapply: KApply) -> str:
173 label = kapply.label.name
174 args = kapply.args
175 unparsed_args = [self._print_kinner(arg) for arg in args]
176 if kapply.is_cell:
177 cell_contents = '\n'.join(unparsed_args).rstrip()
178 cell_str = label + '\n' + indent(cell_contents) + '\n</' + label[1:]
179 return cell_str.rstrip()
180 unparser = self._applied_label_str(label) if label not in self.symbol_table else self.symbol_table[label]
181 return unparser(*unparsed_args)
182
183 def _print_kas(self, kas: KAs) -> str:
184 pattern_str = self._print_kinner(kas.pattern)
185 alias_str = self._print_kinner(kas.alias)
186 return pattern_str + ' #as ' + alias_str
187
188 def _print_krewrite(self, krewrite: KRewrite) -> str:
189 lhs_str = self._print_kinner(krewrite.lhs)
190 rhs_str = self._print_kinner(krewrite.rhs)
191 return '( ' + lhs_str + ' => ' + rhs_str + ' )'
192
193 def _print_ksequence(self, ksequence: KSequence) -> str:
194 if ksequence.arity == 0:
195 # TODO: Would be nice to say `return self._print_kinner(EMPTY_K)`
196 return '.K'
197 if ksequence.arity == 1:
198 return self._print_kinner(ksequence.items[0]) + ' ~> .K'
199 unparsed_k_seq = '\n~> '.join([self._print_kinner(item) for item in ksequence.items[0:-1]])
200 if ksequence.items[-1] == KToken('...', KSort('K')):
201 unparsed_k_seq = unparsed_k_seq + '\n' + self._print_kinner(KToken('...', KSort('K')))
202 else:
203 unparsed_k_seq = unparsed_k_seq + '\n~> ' + self._print_kinner(ksequence.items[-1])
204 return unparsed_k_seq
205
206 def _print_kterminal(self, kterminal: KTerminal) -> str:
207 return '"' + kterminal.value + '"'
208
209 def _print_kregexterminal(self, kregexterminal: KRegexTerminal) -> str:
210 return 'r"' + kregexterminal.regex + '"'
211
212 def _print_knonterminal(self, knonterminal: KNonTerminal) -> str:
213 return self.print(knonterminal.sort)
214
215 def _print_kproduction(self, kproduction: KProduction) -> str:
216 if Atts.KLABEL not in kproduction.att and kproduction.klabel:
217 kproduction = kproduction.update_atts([Atts.KLABEL(kproduction.klabel.name)])
218 syntax_str = 'syntax ' + self.print(kproduction.sort)
219 if kproduction.items:
220 syntax_str += ' ::= ' + ' '.join([self._print_kouter(pi) for pi in kproduction.items])
221 att_str = self.print(kproduction.att)
222 if att_str:
223 syntax_str += ' ' + att_str
224 return syntax_str
225
226 def _print_ksyntaxsort(self, ksyntaxsort: KSyntaxSort) -> str:
227 sort_str = self.print(ksyntaxsort.sort)
228 att_str = self.print(ksyntaxsort.att)
229 return 'syntax ' + sort_str + ' ' + att_str
230
231 def _print_ksortsynonym(self, ksortsynonym: KSortSynonym) -> str:
232 new_sort_str = self.print(ksortsynonym.new_sort)
233 old_sort_str = self.print(ksortsynonym.old_sort)
234 att_str = self.print(ksortsynonym.att)
235 return 'syntax ' + new_sort_str + ' = ' + old_sort_str + ' ' + att_str
236
237 def _print_ksyntaxlexical(self, ksyntaxlexical: KSyntaxLexical) -> str:
238 name_str = ksyntaxlexical.name
239 regex_str = ksyntaxlexical.regex
240 att_str = self.print(ksyntaxlexical.att)
241 # todo: proper escaping
242 return 'syntax lexical ' + name_str + ' = r"' + regex_str + '" ' + att_str
243
244 def _print_ksyntaxassociativity(self, ksyntaxassociativity: KSyntaxAssociativity) -> str:
245 assoc_str = ksyntaxassociativity.assoc.value
246 tags_str = ' '.join(ksyntaxassociativity.tags)
247 att_str = self.print(ksyntaxassociativity.att)
248 return 'syntax associativity ' + assoc_str + ' ' + tags_str + ' ' + att_str
249
250 def _print_ksyntaxpriority(self, ksyntaxpriority: KSyntaxPriority) -> str:
251 priorities_str = ' > '.join([' '.join(group) for group in ksyntaxpriority.priorities])
252 att_str = self.print(ksyntaxpriority.att)
253 return 'syntax priority ' + priorities_str + ' ' + att_str
254
255 def _print_kbubble(self, kbubble: KBubble) -> str:
256 body = '// KBubble(' + kbubble.sentence_type + ', ' + kbubble.content + ')'
257 att_str = self.print(kbubble.att)
258 return body + ' ' + att_str
259
260 def _print_krule(self, kterm: KRule) -> str:
261 body = '\n '.join(self.print(kterm.body).split('\n'))
262 rule_str = 'rule '
263 if Atts.LABEL in kterm.att:
264 rule_str = rule_str + '[' + kterm.att[Atts.LABEL] + ']:'
265 rule_str = rule_str + ' ' + body
266 atts_str = self.print(kterm.att)
267 if kterm.requires != TRUE:
268 requires_str = 'requires ' + '\n '.join(self._print_kast_bool(kterm.requires).split('\n'))
269 rule_str = rule_str + '\n ' + requires_str
270 if kterm.ensures != TRUE:
271 ensures_str = 'ensures ' + '\n '.join(self._print_kast_bool(kterm.ensures).split('\n'))
272 rule_str = rule_str + '\n ' + ensures_str
273 return rule_str + '\n ' + atts_str
274
275 def _print_kclaim(self, kterm: KClaim) -> str:
276 body = '\n '.join(self.print(kterm.body).split('\n'))
277 rule_str = 'claim '
278 if Atts.LABEL in kterm.att:
279 rule_str = rule_str + '[' + kterm.att[Atts.LABEL] + ']:'
280 rule_str = rule_str + ' ' + body
281 atts_str = self.print(kterm.att)
282 if kterm.requires != TRUE:
283 requires_str = 'requires ' + '\n '.join(self._print_kast_bool(kterm.requires).split('\n'))
284 rule_str = rule_str + '\n ' + requires_str
285 if kterm.ensures != TRUE:
286 ensures_str = 'ensures ' + '\n '.join(self._print_kast_bool(kterm.ensures).split('\n'))
287 rule_str = rule_str + '\n ' + ensures_str
288 return rule_str + '\n ' + atts_str
289
290 def _print_kcontext(self, kcontext: KContext) -> str:
291 body = indent(self.print(kcontext.body))
292 context_str = 'context alias ' + body
293 requires_str = ''
294 atts_str = self.print(kcontext.att)
295 if kcontext.requires != TRUE:
296 requires_str = self.print(kcontext.requires)
297 requires_str = 'requires ' + indent(requires_str)
298 return context_str + '\n ' + requires_str + '\n ' + atts_str
299
300 def _print_katt(self, katt: KAtt) -> str:
301 return katt.pretty
302
303 def _print_kimport(self, kimport: KImport) -> str:
304 return ' '.join(['imports', ('public' if kimport.public else 'private'), kimport.name])
305
306 def _print_kflatmodule(self, kflatmodule: KFlatModule) -> str:
307 name = kflatmodule.name
308 imports = '\n'.join([self._print_kouter(kimport) for kimport in kflatmodule.imports])
309 sentences = '\n\n'.join([self._print_kouter(sentence) for sentence in kflatmodule.sentences])
310 contents = imports + '\n\n' + sentences
311 return 'module ' + name + '\n ' + '\n '.join(contents.split('\n')) + '\n\nendmodule'
312
313 def _print_krequire(self, krequire: KRequire) -> str:
314 return 'requires "' + krequire.require + '"'
315
316 def _print_kdefinition(self, kdefinition: KDefinition) -> str:
317 requires = '\n'.join([self._print_kouter(require) for require in kdefinition.requires])
318 modules = '\n\n'.join([self._print_kouter(module) for module in kdefinition.all_modules])
319 return requires + '\n\n' + modules
320
321 def _print_kast_bool(self, kast: KAst) -> str:
322 """Print out KAST requires/ensures clause.
323
324 - Input: KAST Bool for requires/ensures clause.
325 - Output: Best-effort string representation of KAST term.
326 """
327 _LOGGER.debug(f'_print_kast_bool: {kast}')
328 if type(kast) is KApply and kast.label.name in ['_andBool_', '_orBool_']:
329 clauses = [self._print_kast_bool(c) for c in flatten_label(kast.label.name, kast)]
330 head = kast.label.name.replace('_', ' ')
331 if head == ' orBool ':
332 head = ' orBool '
333 separator = ' ' * (len(head) - 7)
334 spacer = ' ' * len(head)
335
336 def join_sep(s: str) -> str:
337 return ('\n' + separator).join(s.split('\n'))
338
339 clauses = (
340 ['( ' + join_sep(clauses[0])]
341 + [head + '( ' + join_sep(c) for c in clauses[1:]]
342 + [spacer + (')' * len(clauses))]
343 )
344 return '\n'.join(clauses)
345 else:
346 return self.print(kast)
347
348 def _applied_label_str(self, symbol: str) -> Callable[..., str]:
349 return lambda *args: symbol + ' ( ' + ' , '.join(args) + ' )'
350
351
[docs]
def build_symbol_table(
    definition: KDefinition,
    extra_modules: Iterable[KFlatModule] = (),
    opinionated: bool = False,
) -> SymbolTable:
    """Build the unparsing symbol table for a definition.

    Args:
        definition: Definition whose modules' syntax productions are registered.
        extra_modules: Further modules to register after the definition's own; ``None`` is treated as empty.
        opinionated: If true, install custom multi-line unparsers for ``#And`` and ``#Or``.

    Returns:
        Dictionary mapping label names to automatically generated unparsers.
    """
    table: SymbolTable = {}

    modules = list(definition.all_modules)
    if extra_modules is not None:
        modules.extend(extra_modules)

    for module in modules:
        for production in module.syntax_productions:
            assert production.klabel
            name = production.klabel.name
            printer = unparser_for_production(production)
            table[name] = printer
            # When both attributes are present, also register the printer under the `klabel` attribute value.
            if Atts.SYMBOL in production.att and Atts.KLABEL in production.att:
                table[production.att[Atts.KLABEL]] = printer

    if opinionated:

        def _and(c1: str, c2: str) -> str:
            return c1 + '\n#And ' + c2

        def _or(c1: str, c2: str) -> str:
            return c1 + '\n#Or\n' + indent(c2, size=4)

        table['#And'] = _and
        table['#Or'] = _or

    return table
379
380
[docs]
def unparser_for_production(prod: KProduction) -> Callable[..., str]:
    """Build an unparser that renders an application of ``prod`` from its printed arguments.

    Terminals are emitted verbatim and each non-terminal consumes the next
    argument, with all pieces joined by single spaces. Non-terminals beyond the
    supplied arguments are skipped. When every non-terminal of the production
    is named, each argument is prefixed with ``name:`` and the first one is
    additionally preceded by ``...``.

    Args:
        prod: Production whose items describe the concrete syntax.

    Returns:
        Callable taking the already-printed arguments and returning the printed term.
    """
    # These depend only on the production, so compute them once instead of on every call.
    nonterminals = [item for item in prod.items if type(item) is KNonTerminal]
    all_named = all(item.name is not None for item in nonterminals)

    def _unparser(*args: Any) -> str:
        index = 0
        result = []
        for item in prod.items:
            if type(item) is KTerminal:
                result.append(item.value)
            elif type(item) is KNonTerminal and index < len(args):
                if all_named:
                    if index == 0:
                        result.append('...')
                    result.append(f'{item.name}:')
                result.append(args[index])
                index += 1
        return ' '.join(result)

    return _unparser
400
401
[docs]
def indent(text: str, size: int = 2) -> str:
    """Prefix every line of ``text`` (split on newlines) with ``size`` spaces."""
    prefix = ' ' * size
    return '\n'.join(prefix + line for line in text.split('\n'))
404
405
[docs]
def paren(printer: Callable[..., str]) -> Callable[..., str]:
    """Wrap ``printer`` so that its output is enclosed in ``( `` and `` )``."""

    def _parenthesized(*args: object) -> str:
        inner = printer(*args)
        return '( ' + inner + ' )'

    return _parenthesized
408
409
[docs]
def assoc_with_unit(assoc_join: str, unit: str) -> Callable[..., str]:
    """Return an unparser that joins its arguments with ``assoc_join``, dropping any equal to ``unit``."""

    def _assoc_with_unit(*args: str) -> str:
        kept = [arg for arg in args if arg != unit]
        return assoc_join.join(kept)

    return _assoc_with_unit
414 return _assoc_with_unit