Skip to content

Commit

Permalink
Merge pull request #4 from eagomez2/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
eagomez2 authored Sep 25, 2024
2 parents f1d606b + 831c5cb commit f3cfd52
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 8 deletions.
4 changes: 3 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,6 @@ By default, all methods support all modules as long as these are instances of `t
| `torch.nn.AdaptiveMaxPool2d` | :material-check: | 0.0.1 |
| `torch.nn.MaxPool1d` | :material-check: | 0.0.1 |
| `torch.nn.MaxPool2d` | :material-check: | 0.0.1 |
| `torch.nn.LayerNorm` | :material-check: | 0.0.1 |
| `torch.nn.LayerNorm` | :material-check: | 0.0.1 |
| `torch.nn.BatchNorm1d` | :material-check: | 0.0.4 |
| `torch.nn.BatchNorm2d` | :material-check: | 0.0.4 |
8 changes: 4 additions & 4 deletions docs/modules/layernorm.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Where
## Complexity
The complexity of a `torch.nn.LayerNorm` layer can be divided into two parts: The aggregated statistics calculation (i.e. mean and standard deviation) and the affine transformation applied by $\gamma$ and $\beta$ if `elementwise_affine=True`.

## Aggregated statistics
### Aggregated statistics
The complexity of the mean corresponds to the sum of all elements in the last $D$ dimensions of the input tensor $x$ and the division of that number by the total number of elements. As an example, if `normalized_shape=(3, 5)` then there are 14 additions and 1 division. This also corresponds to the product of the dimensions involved in `normalized_shape`.

$$
Expand Down Expand Up @@ -63,7 +63,7 @@ $$
\end{equation}
$$

## Elementwise affine
### Elementwise affine
If `elementwise_affine=True`, there is an element-wise multiplication by $\gamma$. If `bias=True`, there is also an element-wise addition by $\beta$. Therefore the whole complexity of affine transformations is

$$
Expand All @@ -82,15 +82,15 @@ $$

when `bias=True`.

## Batch size
### Batch size
So far we have not included the batch size $N$, which in this case could be defined as all other dimensions that are not $D$. This means, those that are not included in `normalized_shape`.

!!! note
Please note that $N$ here corresponds to all dimensions not included in `normalized_shape`, which is different from the definition of $N$ in `torch.var`, which corresponds to the number of elements in the input tensor of that function.

The batch size $N$ multiplies all previously calculated operations by a factor $\eta$ corresponding to the multiplication of the remaining dimensions. For example, if the input tensor has size `(2, 3, 5)` and `normalized_shape=(3, 5)`, then $\eta$ is $2$.

## Total complexity
### Total complexity
Including all previously calculated factors, the total complexity can be summarized as

$$
Expand Down
2 changes: 1 addition & 1 deletion src/moduleprofiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"Module profiler"
__version__ = "0.0.3"
__version__ = "0.0.4"

__all__ = [
"get_default_ops_map",
Expand Down
59 changes: 59 additions & 0 deletions src/moduleprofiler/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ def _default_ops_fn(
return None


def _excluded_ops_fn(
module: nn.Module,
input: Tuple[torch.Tensor],
output: torch.Tensor
) -> Any:
return None


def _identity_ops_fn(
module: nn.Identity,
input: Tuple[torch.Tensor],
Expand Down Expand Up @@ -584,11 +592,60 @@ def _avgpool2d_ops_fn(
)


def _batchnorm1d_ops_fn(
module: nn.BatchNorm1d,
input: Tuple[torch.Tensor],
output: torch.Tensor
) -> int:
if input[0].ndim == 2:
num_elements = input[0].size(0)

elif input[0].ndim == 3:
num_elements = input[0].size(0) * input[0].size(-1)

else:
raise AssertionError

if not module.affine:
total_ops = 5 * num_elements + 4

else:
total_ops = 7 * num_elements + 4

# Add num_features C
total_ops *= module.num_features

return total_ops


def _batchnorm2d_ops_fn(
module: nn.BatchNorm2d,
input: Tuple[torch.Tensor],
output: torch.Tensor
) -> int:
num_elements = input[0].size(0) * input[0].size(-1) * input[0].size(-2)

if not module.affine:
total_ops = 5 * num_elements + 4

else:
total_ops = 7 * num_elements + 4

# Add num_features C
total_ops *= module.num_features

return total_ops



def get_default_ops_map() -> dict:
return {
# Default method
"default": _default_ops_fn,

# Excluded module method
"excluded": _excluded_ops_fn,

# Layers
nn.Identity: _identity_ops_fn,
nn.Linear: _linear_ops_fn,
Expand All @@ -602,6 +659,8 @@ def get_default_ops_map() -> dict:
nn.LSTM: _lstm_ops_fn,

# Norm
nn.BatchNorm1d: _batchnorm1d_ops_fn,
nn.BatchNorm2d: _batchnorm2d_ops_fn,
nn.LayerNorm: _layernorm_ops_fn,

# Pooling
Expand Down
20 changes: 18 additions & 2 deletions src/moduleprofiler/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Union
Expand Down Expand Up @@ -43,6 +44,8 @@ class ModuleProfiler:
their corresponding functions used to trace its size.
ops_fn_map (dict): Dictionary containing a map between modules and
their corresponding function to estimate the number of operations.
exclude_from_ops (Optional[List[nn.Module]]): Modules to exclude from
ops estimations.
ts_fmt (str): Timestamp format used to print messages if
`verbose=True`.
verbose (bool): If ``True``, enabled verbose output mode.
Expand All @@ -56,6 +59,7 @@ def __init__(
inference_end_attr: str = "__inference_end__",
io_size_fn_map: dict = get_default_io_size_map(),
ops_fn_map: dict = get_default_ops_map(),
exclude_from_ops: Optional[List[nn.Module]] = None,
ts_fmt: str = "%Y-%m-%d %H:%M:%S",
verbose: bool = False
) -> None:
Expand All @@ -69,6 +73,7 @@ def __init__(
self.inference_end_attr = inference_end_attr
self.io_size_fn_map = io_size_fn_map
self.ops_fn_map = ops_fn_map
self.exclude_from_ops = exclude_from_ops
self.verbose = verbose
self._logger = Logger(ts_fmt=ts_fmt)
self._hook_handles = []
Expand Down Expand Up @@ -312,7 +317,14 @@ def _ops_fn(
"""
# Obtain method to estimate ops
if module.__class__ in self.ops_fn_map:
ops_fn = self.ops_fn_map[type(module)]
if (
self.exclude_from_ops is not None
and module.__class__ in self.exclude_from_ops
):
ops_fn = self.ops_fn_map["excluded"]

else:
ops_fn = self.ops_fn_map[type(module)]

else:
ops_fn = self.ops_fn_map["default"]
Expand Down Expand Up @@ -368,7 +380,11 @@ def count_params(
data[n] = {
"type": m.__class__.__name__,
"trainable_params": 0,
"nontrainable_params": 0
"trainable_params_dtype": None,
"trainable_params_size_bits": 0,
"nontrainable_params": 0,
"nontrainable_params_dtype": None,
"nontrainable_params_size_bits": 0
}

for p in m.parameters():
Expand Down

0 comments on commit f3cfd52

Please sign in to comment.