Skip to content

Commit

Permalink
Update stablehlo.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai authored Nov 29, 2023
1 parent 8384841 commit b60eabe
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions torch_xla/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,14 @@ class StableHLOModelBundle:
stablehlo_funcs: List[StableHLOFunc]


@dataclass
class StableHLOExportOptions:
include_human_readable_text: bool = True
override_tracing_arguments: Optional[Tuple[Any]] = None
override_tracing_kwargs: Optional[Mapping[str, Any]] = None
save_weights: bool = True


class XLAExportInterpreter(torch.fx.Interpreter):

def __init__(self, module, device):
Expand Down Expand Up @@ -456,14 +464,6 @@ def _load_program_bundle(stablehlo_dir: os.PathLike) -> StableHLOModelBundle:
state_dict=state_dict)


@dataclass
class StableHLOExportOptions:
include_human_readable_text: bool = True
override_tracing_arguments: Optional[Tuple[Any]] = None
override_tracing_kwargs: Optional[Mapping[str, Any]] = None
save_weights: bool = True


def save_as_stablehlo(exported_model: 'ExportedProgram',
stablehlo_dir: os.PathLike,
options: Optional[StableHLOExportOptions] = None):
Expand Down

0 comments on commit b60eabe

Please sign in to comment.