Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate Lean 4 type definitions from a KORE definition #4717

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pyk/src/pyk/k2lean4/Prelude.lean
JuanCoRo marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
abbrev SortBool : Type := Int
abbrev SortBytes: Type := ByteArray
abbrev SortId : Type := String
abbrev SortInt : Type := Int
abbrev SortString : Type := String
abbrev SortStringBuffer : Type := String

abbrev ListHook (E : Type) : Type := List E
abbrev MapHook (K : Type) (V : Type) : Type := List (K × V)
abbrev SetHook (E : Type) : Type := List E
Empty file added pyk/src/pyk/k2lean4/__init__.py
Empty file.
68 changes: 68 additions & 0 deletions pyk/src/pyk/k2lean4/k2lean4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from ..kore.internal import CollectionKind
from ..kore.syntax import SortApp
from ..utils import check_type
from .model import Abbrev, Ctor, ExplBinder, Inductive, Module, Signature, Term

if TYPE_CHECKING:
from ..kore.internal import KoreDefn
from .model import Command


@dataclass(frozen=True)
class K2Lean4:
defn: KoreDefn

def sort_module(self) -> Module:
commands = []
commands += self._inductives()
commands += self._collections()
return Module(commands=commands)

def _inductives(self) -> list[Command]:
def is_inductive(sort: str) -> bool:
decl = self.defn.sorts[sort]
return not decl.hooked and 'hasDomainValues' not in decl.attrs_by_key

sorts = sorted(sort for sort in self.defn.sorts if is_inductive(sort))
return [self._inductive(sort) for sort in sorts]

def _inductive(self, sort: str) -> Inductive:
subsorts = sorted(self.defn.subsorts.get(sort, ()))
symbols = sorted(self.defn.constructors.get(sort, ()))
ctors: list[Ctor] = []
ctors.extend(self._inj_ctor(sort, subsort) for subsort in subsorts)
ctors.extend(self._symbol_ctor(sort, symbol) for symbol in symbols)
return Inductive(sort, Signature((), Term('Type')), ctors=ctors)

def _inj_ctor(self, sort: str, subsort: str) -> Ctor:
return Ctor(f'inj_{subsort}', Signature((ExplBinder(('x',), Term(subsort)),), Term(sort)))

def _symbol_ctor(self, sort: str, symbol: str) -> Ctor:
param_sorts = (
check_type(sort, SortApp).name for sort in self.defn.symbols[symbol].param_sorts
) # TODO eliminate check_type
binders = tuple(ExplBinder((f'x{i}',), Term(sort)) for i, sort in enumerate(param_sorts))
symbol = symbol.replace('-', '_')
return Ctor(symbol, Signature(binders, Term(sort)))

def _collections(self) -> list[Command]:
return [self._collection(sort) for sort in sorted(self.defn.collections)]

def _collection(self, sort: str) -> Abbrev:
coll = self.defn.collections[sort]
elem = self.defn.symbols[coll.element]
sorts = ' '.join(check_type(sort, SortApp).name for sort in elem.param_sorts) # TODO eliminate check_type
assert sorts
match coll.kind:
case CollectionKind.LIST:
val = Term(f'ListHook {sorts}')
case CollectionKind.MAP:
val = Term(f'MapHook {sorts}')
case CollectionKind.SET:
val = Term(f'SetHook {sorts}')
return Abbrev(sort, val, Signature((), Term('Type')))
306 changes: 306 additions & 0 deletions pyk/src/pyk/k2lean4/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
from __future__ import annotations

from abc import ABC
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, final

if TYPE_CHECKING:
from collections.abc import Iterable


def indent(text: str, n: int) -> str:
indent = n * ' '
res = []
for line in text.splitlines():
res.append(f'{indent}{line}' if line else '')
return '\n'.join(res)


@final
@dataclass(frozen=True)
class Module:
commands: tuple[Command, ...]

def __init__(self, commands: Iterable[Command] | None = None):
commands = tuple(commands) if commands is not None else ()
object.__setattr__(self, 'commands', commands)

def __str__(self) -> str:
return '\n\n'.join(str(command) for command in self.commands)


class Command(ABC): ...


class Mutual(Command):
commands: tuple[Command, ...]

def __init__(self, commands: Iterable[Command] | None = None):
commands = tuple(commands) if commands is not None else ()
object.__setattr__(self, 'commands', commands)

def __str__(self) -> str:
commands = '\n\n'.join(indent(str(command), 2) for command in self.commands)
return f'mutual\n{commands}\nend'


class Declaration(Command, ABC):
modifiers: Modifiers | None


@final
@dataclass
class Abbrev(Declaration):
ident: DeclId
val: Term # declVal
signature: Signature | None
modifiers: Modifiers | None

def __init__(
self,
ident: str | DeclId,
val: Term,
signature: Signature | None = None,
modifiers: Modifiers | None = None,
):
ident = DeclId(ident) if isinstance(ident, str) else ident
object.__setattr__(self, 'ident', ident)
object.__setattr__(self, 'val', val)
object.__setattr__(self, 'signature', signature)
object.__setattr__(self, 'modifiers', modifiers)

def __str__(self) -> str:
modifiers = f'{self.modifiers} ' if self.modifiers else ''
signature = f' {self.signature}' if self.signature else ''
return f'{modifiers} abbrev {self.ident}{signature} := {self.val}'


@final
@dataclass(frozen=True)
class Inductive(Declaration):
ident: DeclId
signature: Signature | None
ctors: tuple[Ctor, ...]
deriving: tuple[str, ...]
modifiers: Modifiers | None

def __init__(
self,
ident: str | DeclId,
signature: Signature | None = None,
ctors: Iterable[Ctor] | None = None,
deriving: Iterable[str] | None = None,
modifiers: Modifiers | None = None,
):
ident = DeclId(ident) if isinstance(ident, str) else ident
ctors = tuple(ctors) if ctors is not None else ()
deriving = tuple(deriving) if deriving is not None else ()
object.__setattr__(self, 'ident', ident)
object.__setattr__(self, 'signature', signature)
object.__setattr__(self, 'ctors', ctors)
object.__setattr__(self, 'deriving', deriving)
object.__setattr__(self, 'modifiers', modifiers)

def __str__(self) -> str:
modifiers = f'{self.modifiers} ' if self.modifiers else ''
signature = f' {self.signature}' if self.signature else ''
where = ' where' if self.ctors else ''
deriving = ', '.join(self.deriving)

lines = []
lines.append(f'{modifiers}inductive {self.ident}{signature}{where}')
for ctor in self.ctors:
lines.append(f' | {ctor}')
if deriving:
lines.append(f' deriving {deriving}')
return '\n'.join(lines)


@final
@dataclass(frozen=True)
class DeclId:
val: str
uvars: tuple[str, ...]

def __init__(self, val: str, uvars: Iterable[str] | None = None):
uvars = tuple(uvars) if uvars is not None else ()
object.__setattr__(self, 'val', val)
object.__setattr__(self, 'uvars', uvars)

def __str__(self) -> str:
uvars = ', '.join(self.uvars)
uvars = '.{' + uvars + '}' if uvars else ''
return f'{self.val}{uvars}'


@final
@dataclass(frozen=True)
class Ctor:
ident: str
signature: Signature | None = None
modifiers: Modifiers | None = None

def __str__(self) -> str:
modifiers = f'{self.modifiers} ' if self.modifiers else ''
signature = f' {self.signature}' if self.signature else ''
return f'{modifiers}{self.ident}{signature}'


@final
@dataclass(frozen=True)
class Signature:
binders: tuple[Binder, ...]
ty: Term | None

def __init__(self, binders: Iterable[Binder] | None = None, ty: Term | None = None):
binders = tuple(binders) if binders is not None else ()
object.__setattr__(self, 'binders', binders)
object.__setattr__(self, 'ty', ty)

def __str__(self) -> str:
binders = ' '.join(str(binder) for binder in self.binders)
sep = ' ' if self.binders else ''
ty = f'{sep}: {self.ty}' if self.ty else ''
return f'{binders}{ty}'


class Binder(ABC): ...


class BracketBinder(Binder, ABC): ...


@final
@dataclass(frozen=True)
class ExplBinder(BracketBinder):
idents: tuple[str, ...]
ty: Term | None

def __init__(self, idents: Iterable[str], ty: Term | None = None):
object.__setattr__(self, 'idents', tuple(idents))
object.__setattr__(self, 'ty', ty)

def __str__(self) -> str:
idents = ' '.join(self.idents)
ty = '' if self.ty is None else f' : {self.ty}'
return f'({idents}{ty})'


@final
@dataclass(frozen=True)
class ImplBinder(BracketBinder):
idents: tuple[str, ...]
ty: Term | None
strict: bool

def __init__(self, idents: Iterable[str], ty: Term | None = None, *, strict: bool | None = None):
object.__setattr__(self, 'idents', tuple(idents))
object.__setattr__(self, 'ty', ty)
object.__setattr__(self, 'strict', bool(strict))

def __str__(self) -> str:
ldelim, rdelim = ['⦃', '⦄'] if self.strict else ['{', '}']
idents = ' '.join(self.idents)
ty = '' if self.ty is None else f' : {self.ty}'
return f'{ldelim}{idents}{ty}{rdelim}'


@final
@dataclass(frozen=True)
class InstBinder(BracketBinder):
ty: Term
ident: str | None

def __init__(self, ty: Term, ident: str | None = None):
object.__setattr__(self, 'ty', ty)
object.__setattr__(self, 'ident', ident)

def __str__(self) -> str:
ident = f'{self.ident} : ' if self.ident else ''
return f'[{ident}{self.ty}]'


@final
@dataclass(frozen=True)
class Term:
term: str # TODO: refine

def __str__(self) -> str:
return self.term


@final
@dataclass(frozen=True)
class Modifiers:
attrs: tuple[Attr, ...]
visibility: Visibility | None
noncomputable: bool
unsafe: bool
totality: Totality | None

def __init__(
self,
*,
attrs: Iterable[str | Attr] | None = None,
visibility: str | Visibility | None = None,
noncomputable: bool | None = None,
unsafe: bool | None = None,
totality: str | Totality | None = None,
):
attrs = tuple(Attr(attr) if isinstance(attr, str) else attr for attr in attrs) if attrs is not None else ()
visibility = Visibility(visibility) if isinstance(visibility, str) else visibility
noncomputable = bool(noncomputable)
unsafe = bool(unsafe)
totality = Totality(totality) if isinstance(totality, str) else totality
object.__setattr__(self, 'attrs', attrs)
object.__setattr__(self, 'visibility', visibility)
object.__setattr__(self, 'noncomputable', noncomputable)
object.__setattr__(self, 'unsafe', unsafe)
object.__setattr__(self, 'totality', totality)

def __str__(self) -> str:
chunks = []
if self.attrs:
attrs = ', '.join(str(attr) for attr in self.attrs)
chunks.append(f'@[{attrs}]')
if self.visibility:
chunks.append(self.visibility.value)
if self.noncomputable:
chunks.append('noncomputable')
if self.unsafe:
chunks.append('unsafe')
if self.totality:
chunks.append(self.totality.value)
return ' '.join(chunks)


@final
@dataclass(frozen=True)
class Attr:
attr: str
kind: AttrKind | None

def __init__(self, attr: str, kind: AttrKind | None = None):
object.__setattr__(self, 'attr', attr)
object.__setattr__(self, 'kind', kind)

def __str__(self) -> str:
if self.kind:
return f'{self.kind.value} {self.attr}'
return self.attr


class AttrKind(Enum):
SCOPED = 'scoped'
LOCAL = 'local'


class Visibility(Enum):
PRIVATE = 'private'
PROTECTED = 'protected'


class Totality(Enum):
PARTIAL = 'partial'
NONREC = 'nonrec'
Loading
Loading