-
Notifications
You must be signed in to change notification settings - Fork 96
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
vendor aie-python-extras (i.e., don't pip install from github) (#1255)
- Loading branch information
1 parent
7233edb
commit a6ba54b
Showing
15 changed files
with
6,813 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import contextlib | ||
from contextlib import ExitStack, contextmanager | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
from .. import ir | ||
|
||
|
||
@dataclass | ||
class MLIRContext: | ||
context: ir.Context | ||
module: ir.Module | ||
|
||
def __str__(self): | ||
return str(self.module) | ||
|
||
|
||
@contextmanager | ||
def mlir_mod_ctx( | ||
src: Optional[str] = None, | ||
context: ir.Context = None, | ||
location: ir.Location = None, | ||
allow_unregistered_dialects=False, | ||
) -> MLIRContext: | ||
if context is None: | ||
context = ir.Context() | ||
if allow_unregistered_dialects: | ||
context.allow_unregistered_dialects = True | ||
with ExitStack() as stack: | ||
stack.enter_context(context) | ||
if location is None: | ||
location = ir.Location.unknown() | ||
stack.enter_context(location) | ||
if src is not None: | ||
module = ir.Module.parse(src) | ||
else: | ||
module = ir.Module.create() | ||
ip = ir.InsertionPoint(module.body) | ||
stack.enter_context(ip) | ||
yield MLIRContext(context, module) | ||
context._clear_live_operations() | ||
|
||
|
||
class RAIIMLIRContext: | ||
context: ir.Context | ||
location: ir.Location | ||
|
||
def __init__(self, location: Optional[ir.Location] = None): | ||
self.context = ir.Context() | ||
self.context.__enter__() | ||
if location is None: | ||
location = ir.Location.unknown() | ||
self.location = location | ||
self.location.__enter__() | ||
|
||
def __del__(self): | ||
self.location.__exit__(None, None, None) | ||
self.context.__exit__(None, None, None) | ||
# i guess the extension gets destroyed before this object sometimes? | ||
if ir is not None: | ||
assert ir.Context is not self.context | ||
|
||
|
||
class ExplicitlyManagedModule: | ||
module: ir.Module | ||
_ip: ir.InsertionPoint | ||
|
||
def __init__(self): | ||
self.module = ir.Module.create() | ||
self._ip = ir.InsertionPoint(self.module.body) | ||
self._ip.__enter__() | ||
|
||
def finish(self): | ||
self._ip.__exit__(None, None, None) | ||
return self.module | ||
|
||
def __str__(self): | ||
return str(self.module) | ||
|
||
|
||
@contextlib.contextmanager | ||
def enable_multithreading(context=None): | ||
from ..ir import Context | ||
|
||
if context is None: | ||
context = Context.current | ||
context.enable_multithreading(True) | ||
yield | ||
context.enable_multithreading(False) | ||
|
||
|
||
@contextlib.contextmanager | ||
def disable_multithreading(context=None): | ||
from ..ir import Context | ||
|
||
if context is None: | ||
context = Context.current | ||
|
||
context.enable_multithreading(False) | ||
yield | ||
context.enable_multithreading(True) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from functools import cached_property, reduce | ||
from typing import Tuple | ||
|
||
import numpy as np | ||
|
||
from ....ir import DenseElementsAttr, ShapedType, Type | ||
|
||
S = ShapedType.get_dynamic_size() | ||
|
||
|
||
# mixin that requires `is_constant` | ||
class ShapedValue: | ||
@cached_property | ||
def literal_value(self) -> np.ndarray: | ||
if not self.is_constant: | ||
raise ValueError("Can't build literal from non-constant value") | ||
return np.array(DenseElementsAttr(self.owner.opview.value), copy=False) | ||
|
||
@cached_property | ||
def _shaped_type(self) -> ShapedType: | ||
return ShapedType(self.type) | ||
|
||
def has_static_shape(self) -> bool: | ||
return self._shaped_type.has_static_shape | ||
|
||
def has_rank(self) -> bool: | ||
return self._shaped_type.has_rank | ||
|
||
@cached_property | ||
def rank(self) -> int: | ||
return self._shaped_type.rank | ||
|
||
@cached_property | ||
def shape(self) -> Tuple[int, ...]: | ||
return tuple(self._shaped_type.shape) | ||
|
||
@cached_property | ||
def n_elements(self) -> int: | ||
assert self.has_static_shape() | ||
return reduce(lambda acc, v: acc * v, self._shaped_type.shape, 1) | ||
|
||
@cached_property | ||
def dtype(self) -> Type: | ||
return self._shaped_type.element_type |
Oops, something went wrong.