diff --git a/pyk/src/pyk/k2lean4/Prelude.lean b/pyk/src/pyk/k2lean4/Prelude.lean new file mode 100644 index 0000000000..ce14a59bfd --- /dev/null +++ b/pyk/src/pyk/k2lean4/Prelude.lean @@ -0,0 +1,10 @@ +abbrev SortBool : Type := Int +abbrev SortBytes: Type := ByteArray +abbrev SortId : Type := String +abbrev SortInt : Type := Int +abbrev SortString : Type := String +abbrev SortStringBuffer : Type := String + +abbrev ListHook (E : Type) : Type := List E +abbrev MapHook (K : Type) (V : Type) : Type := List (K × V) +abbrev SetHook (E : Type) : Type := List E diff --git a/pyk/src/pyk/k2lean4/__init__.py b/pyk/src/pyk/k2lean4/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pyk/src/pyk/k2lean4/k2lean4.py b/pyk/src/pyk/k2lean4/k2lean4.py new file mode 100644 index 0000000000..892aef8d7b --- /dev/null +++ b/pyk/src/pyk/k2lean4/k2lean4.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from ..kore.internal import CollectionKind +from ..kore.syntax import SortApp +from ..utils import check_type +from .model import Abbrev, Ctor, ExplBinder, Inductive, Module, Signature, Term + +if TYPE_CHECKING: + from ..kore.internal import KoreDefn + from .model import Command + + +@dataclass(frozen=True) +class K2Lean4: + defn: KoreDefn + + def sort_module(self) -> Module: + commands = [] + commands += self._inductives() + commands += self._collections() + return Module(commands=commands) + + def _inductives(self) -> list[Command]: + def is_inductive(sort: str) -> bool: + decl = self.defn.sorts[sort] + return not decl.hooked and 'hasDomainValues' not in decl.attrs_by_key + + sorts = sorted(sort for sort in self.defn.sorts if is_inductive(sort)) + return [self._inductive(sort) for sort in sorts] + + def _inductive(self, sort: str) -> Inductive: + subsorts = sorted(self.defn.subsorts.get(sort, ())) + symbols = sorted(self.defn.constructors.get(sort, ())) + ctors: list[Ctor] = [] + ctors.extend(self._inj_ctor(sort, subsort) for subsort in subsorts) + ctors.extend(self._symbol_ctor(sort, symbol) for symbol in symbols) + return Inductive(sort, Signature((), Term('Type')), ctors=ctors) + + def _inj_ctor(self, sort: str, subsort: str) -> Ctor: + return Ctor(f'inj_{subsort}', Signature((ExplBinder(('x',), Term(subsort)),), Term(sort))) + + def _symbol_ctor(self, sort: str, symbol: str) -> Ctor: + param_sorts = ( + check_type(sort, SortApp).name for sort in self.defn.symbols[symbol].param_sorts + ) # TODO eliminate check_type + binders = tuple(ExplBinder((f'x{i}',), Term(sort)) for i, sort in enumerate(param_sorts)) + symbol = symbol.replace('-', '_') + return Ctor(symbol, Signature(binders, Term(sort))) + + def _collections(self) -> list[Command]: + return [self._collection(sort) for sort in sorted(self.defn.collections)] + + def _collection(self, sort: str) -> Abbrev: + coll = self.defn.collections[sort] + elem = self.defn.symbols[coll.element] + sorts = ' '.join(check_type(sort, SortApp).name for sort in elem.param_sorts) # TODO eliminate check_type + assert sorts + match coll.kind: + case CollectionKind.LIST: + val = Term(f'ListHook {sorts}') + case CollectionKind.MAP: + val = Term(f'MapHook {sorts}') + case CollectionKind.SET: + val = Term(f'SetHook {sorts}') + return Abbrev(sort, val, Signature((), Term('Type'))) diff --git a/pyk/src/pyk/k2lean4/model.py b/pyk/src/pyk/k2lean4/model.py new file mode 100644 index 0000000000..f716a48c36 --- /dev/null +++ b/pyk/src/pyk/k2lean4/model.py @@ -0,0 +1,306 @@ +from __future__ import annotations + +from abc import ABC +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, final + +if TYPE_CHECKING: + from collections.abc import Iterable + + +def indent(text: str, n: int) -> str: + indent = n * ' ' + res = [] + for line in text.splitlines(): + res.append(f'{indent}{line}' if line else '') + return '\n'.join(res) + + +@final +@dataclass(frozen=True) +class Module: + commands: tuple[Command, ...] + + def __init__(self, commands: Iterable[Command] | None = None): + commands = tuple(commands) if commands is not None else () + object.__setattr__(self, 'commands', commands) + + def __str__(self) -> str: + return '\n\n'.join(str(command) for command in self.commands) + + +class Command(ABC): ... + + +class Mutual(Command): + commands: tuple[Command, ...] + + def __init__(self, commands: Iterable[Command] | None = None): + commands = tuple(commands) if commands is not None else () + object.__setattr__(self, 'commands', commands) + + def __str__(self) -> str: + commands = '\n\n'.join(indent(str(command), 2) for command in self.commands) + return f'mutual\n{commands}\nend' + + +class Declaration(Command, ABC): + modifiers: Modifiers | None + + +@final +@dataclass +class Abbrev(Declaration): + ident: DeclId + val: Term # declVal + signature: Signature | None + modifiers: Modifiers | None + + def __init__( + self, + ident: str | DeclId, + val: Term, + signature: Signature | None = None, + modifiers: Modifiers | None = None, + ): + ident = DeclId(ident) if isinstance(ident, str) else ident + object.__setattr__(self, 'ident', ident) + object.__setattr__(self, 'val', val) + object.__setattr__(self, 'signature', signature) + object.__setattr__(self, 'modifiers', modifiers) + + def __str__(self) -> str: + modifiers = f'{self.modifiers} ' if self.modifiers else '' + signature = f' {self.signature}' if self.signature else '' + return f'{modifiers} abbrev {self.ident}{signature} := {self.val}' + + +@final +@dataclass(frozen=True) +class Inductive(Declaration): + ident: DeclId + signature: Signature | None + ctors: tuple[Ctor, ...] + deriving: tuple[str, ...] + modifiers: Modifiers | None + + def __init__( + self, + ident: str | DeclId, + signature: Signature | None = None, + ctors: Iterable[Ctor] | None = None, + deriving: Iterable[str] | None = None, + modifiers: Modifiers | None = None, + ): + ident = DeclId(ident) if isinstance(ident, str) else ident + ctors = tuple(ctors) if ctors is not None else () + deriving = tuple(deriving) if deriving is not None else () + object.__setattr__(self, 'ident', ident) + object.__setattr__(self, 'signature', signature) + object.__setattr__(self, 'ctors', ctors) + object.__setattr__(self, 'deriving', deriving) + object.__setattr__(self, 'modifiers', modifiers) + + def __str__(self) -> str: + modifiers = f'{self.modifiers} ' if self.modifiers else '' + signature = f' {self.signature}' if self.signature else '' + where = ' where' if self.ctors else '' + deriving = ', '.join(self.deriving) + + lines = [] + lines.append(f'{modifiers}inductive {self.ident}{signature}{where}') + for ctor in self.ctors: + lines.append(f' | {ctor}') + if deriving: + lines.append(f' deriving {deriving}') + return '\n'.join(lines) + + +@final +@dataclass(frozen=True) +class DeclId: + val: str + uvars: tuple[str, ...] + + def __init__(self, val: str, uvars: Iterable[str] | None = None): + uvars = tuple(uvars) if uvars is not None else () + object.__setattr__(self, 'val', val) + object.__setattr__(self, 'uvars', uvars) + + def __str__(self) -> str: + uvars = ', '.join(self.uvars) + uvars = '.{' + uvars + '}' if uvars else '' + return f'{self.val}{uvars}' + + +@final +@dataclass(frozen=True) +class Ctor: + ident: str + signature: Signature | None = None + modifiers: Modifiers | None = None + + def __str__(self) -> str: + modifiers = f'{self.modifiers} ' if self.modifiers else '' + signature = f' {self.signature}' if self.signature else '' + return f'{modifiers}{self.ident}{signature}' + + +@final +@dataclass(frozen=True) +class Signature: + binders: tuple[Binder, ...] + ty: Term | None + + def __init__(self, binders: Iterable[Binder] | None = None, ty: Term | None = None): + binders = tuple(binders) if binders is not None else () + object.__setattr__(self, 'binders', binders) + object.__setattr__(self, 'ty', ty) + + def __str__(self) -> str: + binders = ' '.join(str(binder) for binder in self.binders) + sep = ' ' if self.binders else '' + ty = f'{sep}: {self.ty}' if self.ty else '' + return f'{binders}{ty}' + + +class Binder(ABC): ... + + +class BracketBinder(Binder, ABC): ... + + +@final +@dataclass(frozen=True) +class ExplBinder(BracketBinder): + idents: tuple[str, ...] + ty: Term | None + + def __init__(self, idents: Iterable[str], ty: Term | None = None): + object.__setattr__(self, 'idents', tuple(idents)) + object.__setattr__(self, 'ty', ty) + + def __str__(self) -> str: + idents = ' '.join(self.idents) + ty = '' if self.ty is None else f' : {self.ty}' + return f'({idents}{ty})' + + +@final +@dataclass(frozen=True) +class ImplBinder(BracketBinder): + idents: tuple[str, ...] + ty: Term | None + strict: bool + + def __init__(self, idents: Iterable[str], ty: Term | None = None, *, strict: bool | None = None): + object.__setattr__(self, 'idents', tuple(idents)) + object.__setattr__(self, 'ty', ty) + object.__setattr__(self, 'strict', bool(strict)) + + def __str__(self) -> str: + ldelim, rdelim = ['⦃', '⦄'] if self.strict else ['{', '}'] + idents = ' '.join(self.idents) + ty = '' if self.ty is None else f' : {self.ty}' + return f'{ldelim}{idents}{ty}{rdelim}' + + +@final +@dataclass(frozen=True) +class InstBinder(BracketBinder): + ty: Term + ident: str | None + + def __init__(self, ty: Term, ident: str | None = None): + object.__setattr__(self, 'ty', ty) + object.__setattr__(self, 'ident', ident) + + def __str__(self) -> str: + ident = f'{self.ident} : ' if self.ident else '' + return f'[{ident}{self.ty}]' + + +@final +@dataclass(frozen=True) +class Term: + term: str # TODO: refine + + def __str__(self) -> str: + return self.term + + +@final +@dataclass(frozen=True) +class Modifiers: + attrs: tuple[Attr, ...] + visibility: Visibility | None + noncomputable: bool + unsafe: bool + totality: Totality | None + + def __init__( + self, + *, + attrs: Iterable[str | Attr] | None = None, + visibility: str | Visibility | None = None, + noncomputable: bool | None = None, + unsafe: bool | None = None, + totality: str | Totality | None = None, + ): + attrs = tuple(Attr(attr) if isinstance(attr, str) else attr for attr in attrs) if attrs is not None else () + visibility = Visibility(visibility) if isinstance(visibility, str) else visibility + noncomputable = bool(noncomputable) + unsafe = bool(unsafe) + totality = Totality(totality) if isinstance(totality, str) else totality + object.__setattr__(self, 'attrs', attrs) + object.__setattr__(self, 'visibility', visibility) + object.__setattr__(self, 'noncomputable', noncomputable) + object.__setattr__(self, 'unsafe', unsafe) + object.__setattr__(self, 'totality', totality) + + def __str__(self) -> str: + chunks = [] + if self.attrs: + attrs = ', '.join(str(attr) for attr in self.attrs) + chunks.append(f'@[{attrs}]') + if self.visibility: + chunks.append(self.visibility.value) + if self.noncomputable: + chunks.append('noncomputable') + if self.unsafe: + chunks.append('unsafe') + if self.totality: + chunks.append(self.totality.value) + return ' '.join(chunks) + + +@final +@dataclass(frozen=True) +class Attr: + attr: str + kind: AttrKind | None + + def __init__(self, attr: str, kind: AttrKind | None = None): + object.__setattr__(self, 'attr', attr) + object.__setattr__(self, 'kind', kind) + + def __str__(self) -> str: + if self.kind: + return f'{self.kind.value} {self.attr}' + return self.attr + + +class AttrKind(Enum): + SCOPED = 'scoped' + LOCAL = 'local' + + +class Visibility(Enum): + PRIVATE = 'private' + PROTECTED = 'protected' + + +class Totality(Enum): + PARTIAL = 'partial' + NONREC = 'nonrec' diff --git a/pyk/src/pyk/kore/internal.py b/pyk/src/pyk/kore/internal.py index 9d21f91f9e..95eb918bc7 100644 --- a/pyk/src/pyk/kore/internal.py +++ b/pyk/src/pyk/kore/internal.py @@ -1,18 +1,66 @@ from __future__ import annotations from dataclasses import dataclass +from enum import Enum from functools import cached_property -from typing import TYPE_CHECKING, final +from typing import TYPE_CHECKING, NamedTuple, final from ..utils import FrozenDict, POSet, check_type from .manip import collect_symbols from .rule import FunctionRule, RewriteRule, Rule -from .syntax import App, Axiom, SortApp, SortDecl, Symbol, SymbolDecl +from .syntax import App, Axiom, SortApp, SortDecl, String, Symbol, SymbolDecl if TYPE_CHECKING: from .syntax import Definition +class CollectionKind(Enum): + LIST = 'List' + MAP = 'Map' + SET = 'Set' + + +class Collection(NamedTuple): + sort: str + concat: str + element: str + unit: str + kind: CollectionKind + + @staticmethod + def from_decl(decl: SortDecl) -> Collection: + if not 'element' in decl.attrs_by_key: + raise ValueError(f'Not a sort declaration: {decl.text}') + sort = decl.name + concat = Collection._extract_label(decl, 'concat') + element = Collection._extract_label(decl, 'element') + unit = Collection._extract_label(decl, 'unit') + kind = CollectionKind(Collection._extract_string(decl, 'hook').split('.')[1]) # 'MAP.Map' -> CollectionKind.MAP + return Collection( + sort=sort, + concat=concat, + element=element, + unit=unit, + kind=kind, + ) + + @staticmethod + def _extract_label(decl: SortDecl, attr: str) -> str: + match decl.attrs_by_key[attr]: + case App(attr, args=(App(res),)): + return res + case _: + raise AssertionError() + + @staticmethod + def _extract_string(decl: SortDecl, attr: str) -> str: + match decl.attrs_by_key[attr]: + case App(attr, args=(String(res),)): + return res + case _: + raise AssertionError() + + @final @dataclass(frozen=True) class KoreDefn: @@ -55,7 +103,7 @@ def from_definition(defn: Definition) -> KoreDefn: ) @cached_property - def ctor_symbols(self) -> FrozenDict[str, tuple[str, ...]]: + def constructors(self) -> FrozenDict[str, tuple[str, ...]]: grouped: dict[str, list[str]] = {} for symbol, decl in self.symbols.items(): if not 'constructor' in decl.attrs_by_key: @@ -64,6 +112,17 @@ def ctor_symbols(self) -> FrozenDict[str, tuple[str, ...]]: grouped.setdefault(sort, []).append(symbol) return FrozenDict((sort, tuple(symbols)) for sort, symbols in grouped.items()) + @cached_property + def collections(self) -> FrozenDict[str, Collection]: + colls: dict[str, Collection] = {} + for sort, decl in self.sorts.items(): + if not 'element' in decl.attrs_by_key: + continue + coll = Collection.from_decl(decl) + assert sort == coll.sort + colls[sort] = coll + return FrozenDict(colls) + def let( self, *, @@ -130,7 +189,13 @@ def _config_symbols(self) -> set[str]: if sort in done: continue done.add(sort) - symbols = self.ctor_symbols.get(sort, ()) + + symbols: list[str] = [] + if sort in self.collections: + coll = self.collections[sort] + symbols += (coll.concat, coll.element, coll.unit) + symbols += self.constructors.get(sort, ()) + pending.extend(sort for symbol in symbols for sort in self._symbol_sorts(symbol)) res.update(symbols) return res