Skip to content

Commit

Permalink
InputGen: introduce random manager
Browse files Browse the repository at this point in the history
Summary:
Introduces a random manager in InputGen. This allows to generate reproducible data, by seeding the random manager.
```
from inputgen.utils.random_manager import random_manager
random_manager.seed(1729)
```

Differential Revision: D59668295
  • Loading branch information
manuelcandales authored and facebook-github-bot committed Jul 12, 2024
1 parent a88dd7e commit 3003a94
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 24 deletions.
35 changes: 35 additions & 0 deletions examples/random_seed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from inputgen.argtuple.gen import ArgumentTupleGenerator
from inputgen.utils.random_manager import random_manager
from specdb.db import SpecDictDB


def main():
# example to seed all random number generators
random_manager.seed(1729)

spec = SpecDictDB["add.Tensor"]
op = torch.ops.aten.add.Tensor
for ix, (posargs, inkwargs, outargs) in enumerate(
ArgumentTupleGenerator(spec).gen()
):
op(*posargs, **inkwargs, **outargs)
print(
posargs[0].shape,
posargs[0].dtype,
posargs[1].shape,
posargs[1].dtype,
inkwargs["alpha"],
)
if ix == 1:
print(posargs[0])


if __name__ == "__main__":
main()
8 changes: 5 additions & 3 deletions inputgen/argument/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import random
from typing import Any, List, Optional, Tuple, Union

import torch
Expand All @@ -13,6 +12,7 @@
from inputgen.attribute.model import Attribute
from inputgen.attribute.solve import AttributeSolver
from inputgen.specs.model import Constraint, ConstraintSuffix
from inputgen.utils.random_manager import random_manager as rm
from inputgen.variable.type import ScalarDtype


Expand Down Expand Up @@ -60,7 +60,9 @@ def gen_structure_with_depth_and_length(
yield from self.gen_structure_with_depth(depth, focus, length)
return

focus_ixs = range(length) if focus == attr else (random.choice(range(length)),)
focus_ixs = (
range(length) if focus == attr else (rm.get_random().choice(range(length)),)
)
for focus_ix in focus_ixs:
values = [()]
for ix in range(length):
Expand Down Expand Up @@ -241,7 +243,7 @@ def gen_value_spaces(self, focus, dtype, struct):
if focus == Attribute.VALUE:
return [v.space for v in variables]
else:
return [random.choice(variables).space]
return [rm.get_random().choice(variables).space]

def gen(self, focus):
# TODO(mcandales): Enable Tensor List generation
Expand Down
39 changes: 33 additions & 6 deletions inputgen/argument/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import torch
from inputgen.argument.engine import MetaArg
from inputgen.utils.random_manager import random_manager
from inputgen.variable.gen import VariableGenerator
from inputgen.variable.space import VariableSpace
from torch.testing._internal.common_dtype import floating_types, integral_types
Expand Down Expand Up @@ -41,6 +42,8 @@ def gen(self):
)

def get_random_tensor(self, size, dtype, high=None, low=None):
torch_rng = random_manager.get_torch()

if low is None and high is None:
low = -100
high = 100
Expand All @@ -55,7 +58,9 @@ def get_random_tensor(self, size, dtype, high=None, low=None):
elif not self.space.contains(1):
return torch.full(size, False, dtype=dtype)
else:
return torch.randint(low=0, high=2, size=size, dtype=dtype)
return torch.randint(
low=0, high=2, size=size, dtype=dtype, generator=torch_rng
)

if dtype in integral_types():
low = math.ceil(low)
Expand All @@ -68,16 +73,38 @@ def get_random_tensor(self, size, dtype, high=None, low=None):

if dtype == torch.uint8:
if not self.space.contains(0):
return torch.randint(low=max(1, low), high=high, size=size, dtype=dtype)
return torch.randint(
low=max(1, low),
high=high,
size=size,
dtype=dtype,
generator=torch_rng,
)
else:
return torch.randint(low=max(0, low), high=high, size=size, dtype=dtype)
return torch.randint(
low=max(0, low),
high=high,
size=size,
dtype=dtype,
generator=torch_rng,
)

t = torch.randint(low=low, high=high, size=size, dtype=dtype)
t = torch.randint(
low=low, high=high, size=size, dtype=dtype, generator=torch_rng
)
if not self.space.contains(0):
if high > 0:
pos = torch.randint(low=max(1, low), high=high, size=size, dtype=dtype)
pos = torch.randint(
low=max(1, low),
high=high,
size=size,
dtype=dtype,
generator=torch_rng,
)
else:
pos = torch.randint(low=low, high=0, size=size, dtype=dtype)
pos = torch.randint(
low=low, high=0, size=size, dtype=dtype, generator=torch_rng
)
t = torch.where(t == 0, pos, t)

if dtype in integral_types():
Expand Down
4 changes: 2 additions & 2 deletions inputgen/attribute/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from inputgen.attribute.solve import AttributeSolver
from inputgen.specs.model import Constraint
from inputgen.variable.gen import VariableGenerator
from inputgen.variable.type import ScalarDtype
from inputgen.variable.type import ScalarDtype, sort_values_of_type


class AttributeEngine(AttributeSolver):
Expand Down Expand Up @@ -51,4 +51,4 @@ def gen(self, focus: Attribute, *args):
if len(vals) == 0:
vals = VariableGenerator(variable.space).gen(num)
gen_vals.update(vals)
return gen_vals
return sort_values_of_type(variable.vtype, gen_vals)
Empty file added inputgen/utils/__init__.py
Empty file.
33 changes: 33 additions & 0 deletions inputgen/utils/random_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import random

import torch


class RandomManager:
def __init__(self):
self._rng = random.Random()
self._torch_rng = torch.Generator()

def seed(self, seed):
"""
Seeds the random number generators for random and torch.
"""
self._rng.seed(seed)
self._torch_rng.manual_seed(seed)

def get_random(self):
# self._rng.seed(42)
return self._rng

def get_torch(self):
# self._torch_rng.manual_seed(42)
return self._torch_rng


random_manager = RandomManager()
28 changes: 16 additions & 12 deletions inputgen/variable/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
# LICENSE file in the root directory of this source tree.

import math
import random
from typing import Any, List, Optional, Set, Union

from inputgen.utils.random_manager import random_manager as rm
from inputgen.variable.constants import BOUND_ON_INF, INT64_MAX, INT64_MIN
from inputgen.variable.space import Interval, Intervals, VariableSpace
from inputgen.variable.type import sort_values_of_type
from inputgen.variable.utils import nextdown, nextup


Expand Down Expand Up @@ -51,7 +52,7 @@ def gen_float_from_interval(r: Interval) -> Optional[float]:
elif lower > upper:
return None
else:
return random.uniform(lower, upper)
return rm.get_random().uniform(lower, upper)


def gen_min_float_from_intervals(rs: Intervals) -> Optional[float]:
Expand All @@ -69,7 +70,7 @@ def gen_max_float_from_intervals(rs: Intervals) -> Optional[float]:
def gen_float_from_intervals(rs: Intervals) -> Optional[float]:
if rs.empty():
return None
r = random.choice(rs.intervals)
r = rm.get_random().choice(rs.intervals)
return gen_float_from_interval(r)


Expand Down Expand Up @@ -112,7 +113,7 @@ def gen_int_from_interval(r: Interval) -> Optional[int]:
elif upper is None:
upper = max(lower, 0) + BOUND_ON_INF
assert lower is not None and upper is not None
return random.randint(lower, upper)
return rm.get_random().randint(lower, upper)


def gen_min_int_from_intervals(rs: Intervals) -> Optional[int]:
Expand All @@ -133,7 +134,7 @@ def gen_int_from_intervals(rs: Intervals) -> Optional[int]:
intervals_with_ints = [r for r in rs.intervals if r.contains_int()]
if len(intervals_with_ints) == 0:
return None
r = random.choice(intervals_with_ints)
r = rm.get_random().choice(intervals_with_ints)
return gen_int_from_interval(r)


Expand All @@ -147,6 +148,12 @@ def __init__(self, space: VariableSpace):
self.vtype = space.vtype
self.space = space

def _sorted(self, values: Set[Any]) -> List[Any]:
return sort_values_of_type(self.vtype, values)

def _sample(self, values: Set[Any], num: int) -> List[Any]:
return rm.get_random().sample(self._sorted(values), num)

def gen_min(self) -> Any:
"""Returns the minimum value of the space."""
if self.space.empty() or self.vtype not in [bool, int, float]:
Expand Down Expand Up @@ -221,7 +228,7 @@ def gen_edges_non_extreme(self, num: int = 2) -> Set[Any]:
edges_not_extreme = self.gen_edges() - self.gen_extremes()
if num >= len(edges_not_extreme):
return edges_not_extreme
return set(random.sample(list(edges_not_extreme), num))
return set(self._sample(edges_not_extreme, num))

def gen_non_edges(self, num: int = 2) -> Set[Any]:
"""Generates non-edge (or interior) values of the space."""
Expand All @@ -232,7 +239,7 @@ def gen_non_edges(self, num: int = 2) -> Set[Any]:
if self.space.discrete.initialized:
vals = self.space.discrete.values - edge_or_extreme_vals
if num < len(vals):
vals = set(random.sample(list(vals), num))
vals = set(self._sample(vals, num))
else:
for _ in range(100):
v: Optional[Union[int, float]] = None
Expand Down Expand Up @@ -269,11 +276,8 @@ def gen_balanced(self, num: int = 6) -> Set[Any]:

if num >= len(balanced):
return balanced
return set(random.sample(list(balanced), num))
return set(self._sample(balanced, num))

def gen(self, num: int = 6) -> List[Any]:
"""Generates a sorted (if applicable), balanced sample of the space."""
vals = list(self.gen_balanced(num))
if self.vtype in [bool, int, float, str]:
return sorted(vals)
return vals
return sort_values_of_type(self.vtype, self.gen_balanced(num))
12 changes: 11 additions & 1 deletion inputgen/variable/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import math
from enum import Enum
from typing import Any
from typing import Any, List, Set

import torch

Expand Down Expand Up @@ -93,3 +93,13 @@ def convert_to_vtype(vtype: type, v: Any) -> Any:
if vtype == float:
return float(v)
return v


def sort_values_of_type(vtype: type, values: Set[Any]) -> List[Any]:
if vtype in [bool, int, float, str, tuple]:
return sorted(values)
if vtype == torch.dtype:
return [v for v in SUPPORTED_TENSOR_DTYPES if v in values]
if vtype == ScalarDtype:
return [v for v in ScalarDtype if v in values]
return list(values)

0 comments on commit 3003a94

Please sign in to comment.