1from __future__ import annotations
  2
  3import logging
  4from collections.abc import Callable
  5from functools import cached_property
  6from typing import TYPE_CHECKING
  7
  8from ..prelude.kbool import TRUE
  9from .att import Atts, KAtt
 10from .inner import KApply, KAs, KInner, KLabel, KRewrite, KSequence, KSort, KToken, KVariable
 11from .manip import flatten_label, sort_ac_collections, undo_aliases
 12from .outer import (
 13    KBubble,
 14    KClaim,
 15    KContext,
 16    KDefinition,
 17    KFlatModule,
 18    KImport,
 19    KNonTerminal,
 20    KOuter,
 21    KProduction,
 22    KRegexTerminal,
 23    KRequire,
 24    KRule,
 25    KRuleLike,
 26    KSortSynonym,
 27    KSyntaxAssociativity,
 28    KSyntaxLexical,
 29    KSyntaxPriority,
 30    KSyntaxSort,
 31    KTerminal,
 32)
 33
 34if TYPE_CHECKING:
 35    from collections.abc import Iterable
 36    from typing import Any, Final, TypeVar
 37
 38    from .kast import KAst
 39
 40    RL = TypeVar('RL', bound='KRuleLike')
 41
# Module-level logger, named after this module.
_LOGGER: Final = logging.getLogger(__name__)

# Maps a label name to the callable that renders an application of that label
# from the already-unparsed strings of its arguments.
SymbolTable = dict[str, Callable[..., str]]
 45
 46
[docs]
 47class PrettyPrinter:
 48    definition: KDefinition
 49    _extra_unparsing_modules: Iterable[KFlatModule]
 50    _patch_symbol_table: Callable[[SymbolTable], None] | None
 51    _unalias: bool
 52    _sort_collections: bool
 53
 54    def __init__(
 55        self,
 56        definition: KDefinition,
 57        extra_unparsing_modules: Iterable[KFlatModule] = (),
 58        patch_symbol_table: Callable[[SymbolTable], None] | None = None,
 59        unalias: bool = True,
 60        sort_collections: bool = False,
 61    ):
 62        self.definition = definition
 63        self._extra_unparsing_modules = extra_unparsing_modules
 64        self._patch_symbol_table = patch_symbol_table
 65        self._unalias = unalias
 66        self._sort_collections = sort_collections
 67
 68    @cached_property
 69    def symbol_table(self) -> SymbolTable:
 70        symb_table = build_symbol_table(
 71            self.definition,
 72            extra_modules=self._extra_unparsing_modules,
 73            opinionated=True,
 74        )
 75        if self._patch_symbol_table is not None:
 76            self._patch_symbol_table(symb_table)
 77        return symb_table
 78
[docs]
 79    def print(self, kast: KAst) -> str:
 80        """Print out KAST terms/outer syntax.
 81        -   Input: KAST term.
 82        -   Output: Best-effort string representation of KAST term.
 83        """
 84        _LOGGER.debug(f'Unparsing: {kast}')
 85        if type(kast) is KAtt:
 86            return self._print_katt(kast)
 87        if type(kast) is KSort:
 88            return self._print_ksort(kast)
 89        if type(kast) is KLabel:
 90            return self._print_klabel(kast)
 91        elif isinstance(kast, KOuter):
 92            return self._print_kouter(kast)
 93        elif isinstance(kast, KInner):
 94            if self._unalias:
 95                kast = undo_aliases(self.definition, kast)
 96            if self._sort_collections:
 97                kast = sort_ac_collections(kast)
 98            return self._print_kinner(kast)
 99        raise AssertionError(f'Error unparsing: {kast}') 
100
101    def _print_kouter(self, kast: KOuter) -> str:
102        match kast:
103            case KTerminal():
104                return self._print_kterminal(kast)
105            case KRegexTerminal():
106                return self._print_kregexterminal(kast)
107            case KNonTerminal():
108                return self._print_knonterminal(kast)
109            case KProduction():
110                return self._print_kproduction(kast)
111            case KSyntaxSort():
112                return self._print_ksyntaxsort(kast)
113            case KSortSynonym():
114                return self._print_ksortsynonym(kast)
115            case KSyntaxLexical():
116                return self._print_ksyntaxlexical(kast)
117            case KSyntaxAssociativity():
118                return self._print_ksyntaxassociativity(kast)
119            case KSyntaxPriority():
120                return self._print_ksyntaxpriority(kast)
121            case KBubble():
122                return self._print_kbubble(kast)
123            case KRule():
124                return self._print_krule(kast)
125            case KClaim():
126                return self._print_kclaim(kast)
127            case KContext():
128                return self._print_kcontext(kast)
129            case KImport():
130                return self._print_kimport(kast)
131            case KFlatModule():
132                return self._print_kflatmodule(kast)
133            case KRequire():
134                return self._print_krequire(kast)
135            case KDefinition():
136                return self._print_kdefinition(kast)
137            case _:
138                raise AssertionError(f'Error unparsing: {kast}')
139
140    def _print_kinner(self, kast: KInner) -> str:
141        match kast:
142            case KVariable():
143                return self._print_kvariable(kast)
144            case KToken():
145                return self._print_ktoken(kast)
146            case KApply():
147                return self._print_kapply(kast)
148            case KAs():
149                return self._print_kas(kast)
150            case KRewrite():
151                return self._print_krewrite(kast)
152            case KSequence():
153                return self._print_ksequence(kast)
154            case _:
155                raise AssertionError(f'Error unparsing: {kast}')
156
157    def _print_ksort(self, ksort: KSort) -> str:
158        return ksort.name
159
160    def _print_klabel(self, klabel: KLabel) -> str:
161        return klabel.name
162
163    def _print_kvariable(self, kvariable: KVariable) -> str:
164        sort = kvariable.sort
165        if not sort:
166            return kvariable.name
167        return kvariable.name + ':' + sort.name
168
169    def _print_ktoken(self, ktoken: KToken) -> str:
170        return ktoken.token
171
172    def _print_kapply(self, kapply: KApply) -> str:
173        label = kapply.label.name
174        args = kapply.args
175        unparsed_args = [self._print_kinner(arg) for arg in args]
176        if kapply.is_cell:
177            cell_contents = '\n'.join(unparsed_args).rstrip()
178            cell_str = label + '\n' + indent(cell_contents) + '\n</' + label[1:]
179            return cell_str.rstrip()
180        unparser = self._applied_label_str(label) if label not in self.symbol_table else self.symbol_table[label]
181        return unparser(*unparsed_args)
182
183    def _print_kas(self, kas: KAs) -> str:
184        pattern_str = self._print_kinner(kas.pattern)
185        alias_str = self._print_kinner(kas.alias)
186        return pattern_str + ' #as ' + alias_str
187
188    def _print_krewrite(self, krewrite: KRewrite) -> str:
189        lhs_str = self._print_kinner(krewrite.lhs)
190        rhs_str = self._print_kinner(krewrite.rhs)
191        return '( ' + lhs_str + ' => ' + rhs_str + ' )'
192
193    def _print_ksequence(self, ksequence: KSequence) -> str:
194        if ksequence.arity == 0:
195            # TODO: Would be nice to say `return self._print_kinner(EMPTY_K)`
196            return '.K'
197        if ksequence.arity == 1:
198            return self._print_kinner(ksequence.items[0]) + ' ~> .K'
199        unparsed_k_seq = '\n~> '.join([self._print_kinner(item) for item in ksequence.items[0:-1]])
200        if ksequence.items[-1] == KToken('...', KSort('K')):
201            unparsed_k_seq = unparsed_k_seq + '\n' + self._print_kinner(KToken('...', KSort('K')))
202        else:
203            unparsed_k_seq = unparsed_k_seq + '\n~> ' + self._print_kinner(ksequence.items[-1])
204        return unparsed_k_seq
205
206    def _print_kterminal(self, kterminal: KTerminal) -> str:
207        return '"' + kterminal.value + '"'
208
209    def _print_kregexterminal(self, kregexterminal: KRegexTerminal) -> str:
210        return 'r"' + kregexterminal.regex + '"'
211
212    def _print_knonterminal(self, knonterminal: KNonTerminal) -> str:
213        return self.print(knonterminal.sort)
214
215    def _print_kproduction(self, kproduction: KProduction) -> str:
216        if Atts.KLABEL not in kproduction.att and kproduction.klabel:
217            kproduction = kproduction.update_atts([Atts.KLABEL(kproduction.klabel.name)])
218        syntax_str = 'syntax ' + self.print(kproduction.sort)
219        if kproduction.items:
220            syntax_str += ' ::= ' + ' '.join([self._print_kouter(pi) for pi in kproduction.items])
221        att_str = self.print(kproduction.att)
222        if att_str:
223            syntax_str += ' ' + att_str
224        return syntax_str
225
226    def _print_ksyntaxsort(self, ksyntaxsort: KSyntaxSort) -> str:
227        sort_str = self.print(ksyntaxsort.sort)
228        att_str = self.print(ksyntaxsort.att)
229        return 'syntax ' + sort_str + ' ' + att_str
230
231    def _print_ksortsynonym(self, ksortsynonym: KSortSynonym) -> str:
232        new_sort_str = self.print(ksortsynonym.new_sort)
233        old_sort_str = self.print(ksortsynonym.old_sort)
234        att_str = self.print(ksortsynonym.att)
235        return 'syntax ' + new_sort_str + ' = ' + old_sort_str + ' ' + att_str
236
237    def _print_ksyntaxlexical(self, ksyntaxlexical: KSyntaxLexical) -> str:
238        name_str = ksyntaxlexical.name
239        regex_str = ksyntaxlexical.regex
240        att_str = self.print(ksyntaxlexical.att)
241        # todo: proper escaping
242        return 'syntax lexical ' + name_str + ' = r"' + regex_str + '" ' + att_str
243
244    def _print_ksyntaxassociativity(self, ksyntaxassociativity: KSyntaxAssociativity) -> str:
245        assoc_str = ksyntaxassociativity.assoc.value
246        tags_str = ' '.join(ksyntaxassociativity.tags)
247        att_str = self.print(ksyntaxassociativity.att)
248        return 'syntax associativity ' + assoc_str + ' ' + tags_str + ' ' + att_str
249
250    def _print_ksyntaxpriority(self, ksyntaxpriority: KSyntaxPriority) -> str:
251        priorities_str = ' > '.join([' '.join(group) for group in ksyntaxpriority.priorities])
252        att_str = self.print(ksyntaxpriority.att)
253        return 'syntax priority ' + priorities_str + ' ' + att_str
254
255    def _print_kbubble(self, kbubble: KBubble) -> str:
256        body = '// KBubble(' + kbubble.sentence_type + ', ' + kbubble.content + ')'
257        att_str = self.print(kbubble.att)
258        return body + ' ' + att_str
259
260    def _print_krule(self, kterm: KRule) -> str:
261        body = '\n     '.join(self.print(kterm.body).split('\n'))
262        rule_str = 'rule '
263        if Atts.LABEL in kterm.att:
264            rule_str = rule_str + '[' + kterm.att[Atts.LABEL] + ']:'
265        rule_str = rule_str + ' ' + body
266        atts_str = self.print(kterm.att)
267        if kterm.requires != TRUE:
268            requires_str = 'requires ' + '\n  '.join(self._print_kast_bool(kterm.requires).split('\n'))
269            rule_str = rule_str + '\n  ' + requires_str
270        if kterm.ensures != TRUE:
271            ensures_str = 'ensures ' + '\n  '.join(self._print_kast_bool(kterm.ensures).split('\n'))
272            rule_str = rule_str + '\n   ' + ensures_str
273        return rule_str + '\n  ' + atts_str
274
275    def _print_kclaim(self, kterm: KClaim) -> str:
276        body = '\n     '.join(self.print(kterm.body).split('\n'))
277        rule_str = 'claim '
278        if Atts.LABEL in kterm.att:
279            rule_str = rule_str + '[' + kterm.att[Atts.LABEL] + ']:'
280        rule_str = rule_str + ' ' + body
281        atts_str = self.print(kterm.att)
282        if kterm.requires != TRUE:
283            requires_str = 'requires ' + '\n  '.join(self._print_kast_bool(kterm.requires).split('\n'))
284            rule_str = rule_str + '\n  ' + requires_str
285        if kterm.ensures != TRUE:
286            ensures_str = 'ensures ' + '\n  '.join(self._print_kast_bool(kterm.ensures).split('\n'))
287            rule_str = rule_str + '\n   ' + ensures_str
288        return rule_str + '\n  ' + atts_str
289
290    def _print_kcontext(self, kcontext: KContext) -> str:
291        body = indent(self.print(kcontext.body))
292        context_str = 'context alias ' + body
293        requires_str = ''
294        atts_str = self.print(kcontext.att)
295        if kcontext.requires != TRUE:
296            requires_str = self.print(kcontext.requires)
297            requires_str = 'requires ' + indent(requires_str)
298        return context_str + '\n  ' + requires_str + '\n  ' + atts_str
299
300    def _print_katt(self, katt: KAtt) -> str:
301        return katt.pretty
302
303    def _print_kimport(self, kimport: KImport) -> str:
304        return ' '.join(['imports', ('public' if kimport.public else 'private'), kimport.name])
305
306    def _print_kflatmodule(self, kflatmodule: KFlatModule) -> str:
307        name = kflatmodule.name
308        imports = '\n'.join([self._print_kouter(kimport) for kimport in kflatmodule.imports])
309        sentences = '\n\n'.join([self._print_kouter(sentence) for sentence in kflatmodule.sentences])
310        contents = imports + '\n\n' + sentences
311        return 'module ' + name + '\n    ' + '\n    '.join(contents.split('\n')) + '\n\nendmodule'
312
313    def _print_krequire(self, krequire: KRequire) -> str:
314        return 'requires "' + krequire.require + '"'
315
316    def _print_kdefinition(self, kdefinition: KDefinition) -> str:
317        requires = '\n'.join([self._print_kouter(require) for require in kdefinition.requires])
318        modules = '\n\n'.join([self._print_kouter(module) for module in kdefinition.all_modules])
319        return requires + '\n\n' + modules
320
321    def _print_kast_bool(self, kast: KAst) -> str:
322        """Print out KAST requires/ensures clause.
323
324        -   Input: KAST Bool for requires/ensures clause.
325        -   Output: Best-effort string representation of KAST term.
326        """
327        _LOGGER.debug(f'_print_kast_bool: {kast}')
328        if type(kast) is KApply and kast.label.name in ['_andBool_', '_orBool_']:
329            clauses = [self._print_kast_bool(c) for c in flatten_label(kast.label.name, kast)]
330            head = kast.label.name.replace('_', ' ')
331            if head == ' orBool ':
332                head = '  orBool '
333            separator = ' ' * (len(head) - 7)
334            spacer = ' ' * len(head)
335
336            def join_sep(s: str) -> str:
337                return ('\n' + separator).join(s.split('\n'))
338
339            clauses = (
340                ['( ' + join_sep(clauses[0])]
341                + [head + '( ' + join_sep(c) for c in clauses[1:]]
342                + [spacer + (')' * len(clauses))]
343            )
344            return '\n'.join(clauses)
345        else:
346            return self.print(kast)
347
348    def _applied_label_str(self, symbol: str) -> Callable[..., str]:
349        return lambda *args: symbol + ' ( ' + ' , '.join(args) + ' )' 
350
351
[docs]
def build_symbol_table(
    definition: KDefinition,
    extra_modules: Iterable[KFlatModule] = (),
    opinionated: bool = False,
) -> SymbolTable:
    """Build the unparsing symbol table for a definition.

    Every syntax production of every module (those of ``definition`` followed
    by ``extra_modules``) is registered under its klabel name.  Productions
    carrying both the ``symbol`` and ``klabel`` attributes are additionally
    registered under the value of the ``klabel`` attribute.

    Args:
        definition: Definition whose modules are scanned.
        extra_modules: Further modules to scan; ``None`` is treated as empty.
        opinionated: When true, ``#And`` and ``#Or`` get multi-line unparsers.

    Returns:
        Dictionary mapping label names to automatically generated unparsers.
    """
    table: SymbolTable = {}

    modules = [*definition.all_modules]
    if extra_modules is not None:
        modules.extend(extra_modules)

    for module in modules:
        for prod in module.syntax_productions:
            assert prod.klabel
            unparser = unparser_for_production(prod)
            table[prod.klabel.name] = unparser

            has_symbol_and_klabel = Atts.SYMBOL in prod.att and Atts.KLABEL in prod.att
            if has_symbol_and_klabel:
                table[prod.att[Atts.KLABEL]] = unparser

    if not opinionated:
        return table

    # Opinionated layout: each conjunct/disjunct starts on its own line.
    table['#And'] = lambda c1, c2: c1 + '\n#And ' + c2
    table['#Or'] = lambda c1, c2: c1 + '\n#Or\n' + indent(c2, size=4)
    return table
379
380
[docs]
def unparser_for_production(prod: KProduction) -> Callable[..., str]:
    """Return a function that renders an application of ``prod``.

    The returned function takes the already-unparsed argument strings and
    walks the production items in order: terminals contribute their value and
    each non-terminal consumes the next argument.  Non-terminals for which no
    argument is left are skipped.  Pieces are joined with single spaces.

    When every non-terminal of the production is named, the output starts
    with ``...`` before the first argument and each argument is preceded by
    ``<name>:``.
    """

    def _unparser(*args: Any) -> str:
        # Named-argument style is used only if no non-terminal lacks a name.
        all_named = all(item.name is not None for item in prod.items if type(item) is KNonTerminal)
        pieces = []
        index = 0
        for item in prod.items:
            if type(item) is KTerminal:
                pieces.append(item.value)
            elif type(item) is KNonTerminal and index < len(args):
                if all_named:
                    if index == 0:
                        pieces.append('...')
                    pieces.append(f'{item.name}:')
                pieces.append(args[index])
                index += 1
        return ' '.join(pieces)

    return _unparser
400
401
[docs]
def indent(text: str, size: int = 2) -> str:
    """Prefix every line of ``text`` (split on ``\\n``) with ``size`` spaces."""
    padding = ' ' * size
    return '\n'.join(padding + line for line in text.split('\n'))
404
405
[docs]
def paren(printer: Callable[..., str]) -> Callable[..., str]:
    """Wrap ``printer`` so that its output is enclosed in ``( `` and `` )``."""

    def _parenthesized(*args):  # type: ignore[no-untyped-def]
        return '( ' + printer(*args) + ' )'

    return _parenthesized
408
409
[docs]
def assoc_with_unit(assoc_join: str, unit: str) -> Callable[..., str]:
    """Return a printer that joins its arguments with ``assoc_join``, dropping units.

    Arguments equal to ``unit`` are left out; if nothing remains the result
    is the empty string.
    """

    def _assoc_with_unit(*args: str) -> str:
        non_units = [arg for arg in args if arg != unit]
        return assoc_join.join(non_units)

    return _assoc_with_unit