From a63b8cbeb0c509a99e2c8a55e7faed014dc2bb51 Mon Sep 17 00:00:00 2001 From: LukasHedegaard Date: Tue, 21 Jun 2022 14:36:09 +0200 Subject: [PATCH 1/2] Fix quantization tests & fix torch & transfmrs req --- nn_pruning/modules/quantization.py | 8 +++----- setup.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/nn_pruning/modules/quantization.py b/nn_pruning/modules/quantization.py index f4b92184..2b198a80 100644 --- a/nn_pruning/modules/quantization.py +++ b/nn_pruning/modules/quantization.py @@ -12,8 +12,7 @@ prepare_fx, prepare_qat_fx, ) -from transformers.modeling_fx_utils import symbolic_trace - +from transformers.utils.fx import symbolic_trace from .quantization_config import create_qconfig @@ -113,11 +112,10 @@ def _prepare( else: model.eval() - traced = symbolic_trace( - model, input_names=input_names, batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices - ) + traced = symbolic_trace(model, input_names=input_names) change_attention_mask_value(traced) + # traced=model prepare_custom_config_dict = {"preserved_attributes": ["config", "dummy_inputs"]} prepared_model = torch_prepare_fn(traced, qconfig_dict, prepare_custom_config_dict) diff --git a/setup.py b/setup.py index 1690cfda..c5701e36 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ def combine_requirements(base_keys): author_email="", license="MIT", packages=["nn_pruning", "nn_pruning.modules"], - install_requires=["click", "transformers>=4.3.0", "torch>=1.6", "scikit-learn>=0.24"], + install_requires=["click", "transformers==4.15.0", "torch==1.9", "scikit-learn>=0.24"], extras_require=extras, test_suite="nose.collector", tests_require=["nose", "nose-cover3"], From 9fe9901490e27f5e4ebf570ad6a8a2e9b314639a Mon Sep 17 00:00:00 2001 From: LukasHedegaard Date: Tue, 21 Jun 2022 14:52:45 +0200 Subject: [PATCH 2/2] Move fixed transformers & torch reqs to tests inst --- setup.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index c5701e36..a2a69bbb 100644 --- a/setup.py +++ b/setup.py @@ -5,14 +5,23 @@ def readme(): with open("README.md") as f: return f.read() + extras = { - "tests": ["pytest"], - "examples": ["numpy>=1.2.0", "datasets>=1.4.1", "ipywidgets>=7.6.3", "matplotlib>=3.3.4", "pandas>=1.2.3"], + "tests": ["pytest", "transformers==4.15.0", "torch==1.9"], + "examples": [ + "numpy>=1.2.0", + "datasets>=1.4.1", + "ipywidgets>=7.6.3", + "matplotlib>=3.3.4", + "pandas>=1.2.3", + ], } + def combine_requirements(base_keys): return list(set(k for v in base_keys for k in extras[v])) + extras["dev"] = combine_requirements([k for k in extras if k != "examples"]) @@ -33,7 +42,7 @@ def combine_requirements(base_keys): author_email="", license="MIT", packages=["nn_pruning", "nn_pruning.modules"], - install_requires=["click", "transformers==4.15.0", "torch==1.9", "scikit-learn>=0.24"], + install_requires=["click", "transformers>=4.3", "torch>=1.6", "scikit-learn>=0.24"], extras_require=extras, test_suite="nose.collector", tests_require=["nose", "nose-cover3"],