From d4dd53d62ec58f7020f5f1f02dbc099ae99fe10a Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Sun, 24 Mar 2024 00:13:46 +0800 Subject: [PATCH] add build and pipeline --- .codespell-whitelist.txt | 1 + .github/merge_rules.yaml | 7 + .github/workflows/push.yaml | 39 +++ .gitignore | 416 +++++++++++++++++++++++++++++++ .pre-commit-config.yaml | 76 ++++++ pyproject.toml | 98 ++++++++ rnabert/__init__.py | 1 - rnabert/configuration_rnabert.py | 3 +- rnabert/modeling_rnabert.py | 97 +++---- rnabert/tokenization_rnabert.py | 7 +- setup.py | 3 + tox.ini | 7 + 12 files changed, 685 insertions(+), 70 deletions(-) create mode 100644 .codespell-whitelist.txt create mode 100644 .github/merge_rules.yaml create mode 100644 .github/workflows/push.yaml create mode 100755 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 pyproject.toml create mode 100755 setup.py create mode 100644 tox.ini diff --git a/.codespell-whitelist.txt b/.codespell-whitelist.txt new file mode 100644 index 0000000..1449575 --- /dev/null +++ b/.codespell-whitelist.txt @@ -0,0 +1 @@ +datas \ No newline at end of file diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml new file mode 100644 index 0000000..ab8ae87 --- /dev/null +++ b/.github/merge_rules.yaml @@ -0,0 +1,7 @@ +- name: merge + patterns: + - rnabert/** + approved_by: + - ZhiyuanChen + mandatory_checks_name: + - push diff --git a/.github/workflows/push.yaml b/.github/workflows/push.yaml new file mode 100644 index 0000000..f00dda5 --- /dev/null +++ b/.github/workflows/push.yaml @@ -0,0 +1,39 @@ +name: push +on: [push, pull_request] +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: 3.x + cache: "pip" + - uses: pre-commit/action@v3.0.0 + release: + if: startsWith(github.event.ref, 'refs/tags/v') + needs: [lint] + environment: pypi + permissions: + contents: write + id-token: write + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: 3.x + cache: "pip" + - name: Install dependencies for building + run: pip install wheel setuptools_scm + - name: build package + run: python setup.py sdist bdist_wheel + - name: create release + uses: "marvinpinto/action-automatic-releases@latest" + with: + repo_token: "${{ secrets.GITHUB_TOKEN }}" + prerelease: false + files: | + dist/* + - name: publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000..b83e9e8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,416 @@ + +## Ignore Visual Studio temporary files, build results, and +## files generated by popular Visual Studio add-ons. +## +## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore + +# User-specific files +*.rsuser +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Mono auto generated files +mono_crash.* + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +[Ww][Ii][Nn]32/ +[Aa][Rr][Mm]/ +[Aa][Rr][Mm]64/ +bld/ +[Bb]in/ +[Oo]bj/ +[Ll]og/ +[Ll]ogs/ + +# Visual Studio 2015/2017 cache/options directory +.vs/ +# Visual Studio Code cache/options directory +.vscode/ +# Uncomment if you have tasks that create the project's static files in wwwroot +#wwwroot/ + +# Visual Studio 2017 auto generated files +Generated\ Files/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +# NUnit +*.VisualState.xml +TestResult.xml +nunit-*.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +# Benchmark Results +BenchmarkDotNet.Artifacts/ + +# .NET Core +project.lock.json +project.fragment.lock.json +artifacts/ + +# ASP.NET Scaffolding +ScaffoldingReadMe.txt + +# StyleCop +StyleCopReport.xml + +# Files built by Visual Studio +*_i.c +*_p.c +*_h.h +*.ilk +*.meta +*.obj +*.iobj +*.pch +*.pdb +*.ipdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*_wpftmp.csproj +*.log +*.tlog +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opendb +*.opensdf +*.sdf +*.cachefile +*.VC.db +*.VC.VC.opendb + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# Visual Studio Trace Files +*.e2e + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# AxoCover is a Code Coverage Tool +.axoCover/* +!.axoCover/settings.json + +# Coverlet is a free, cross platform Code Coverage Tool +coverage*.json +coverage*.xml +coverage*.info + +# Visual Studio code coverage results +*.coverage +*.coveragexml + +# NCrunch +_NCrunch_* +.*crunch*.local.xml +nCrunchTemp_* + +# MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output +*.[Pp]ublish.xml +*.azurePubxml +# Note: Comment the next line if you want to checkin your web deploy settings, +# but database connection strings (with potential passwords) will be unencrypted +*.pubxml +*.publishproj + +# Microsoft Azure Web App publish settings. Comment the next line if you want to +# checkin your Azure Web App publish settings, but sensitive information contained +# in these scripts will be unencrypted +PublishScripts/ + +# NuGet Packages +*.nupkg +# NuGet Symbol Packages +*.snupkg +# The packages folder can be ignored because of Package Restore +**/[Pp]ackages/* +# except build/, which is used as an MSBuild target. +!**/[Pp]ackages/build/ +# Uncomment if necessary however generally it will be regenerated when needed +#!**/[Pp]ackages/repositories.config +# NuGet v3's project.json files produces more ignorable files +*.nuget.props +*.nuget.targets + +# Microsoft Azure Build Output +csx/ +*.build.csdef + +# Microsoft Azure Emulator +ecf/ +rcf/ + +# Windows Store app package directories and files +AppPackages/ +BundleArtifacts/ +Package.StoreAssociation.xml +_pkginfo.txt +*.appx +*.appxbundle +*.appxupload + +# Visual Studio cache files +# files ending in .cache can be ignored +*.[Cc]ache +# but keep track of directories ending in .cache +!?*.[Cc]ache/ + +# Others +ClientBin/ +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.jfm +*.pfx +*.publishsettings +orleans.codegen.cs + +# Including strong name files can present a security risk +# (https://github.com/github/gitignore/pull/2483#issue-259490424) +#*.snk + +# Since there are multiple workflows, uncomment next line to ignore bower_components +# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) +#bower_components/ + +# RIA/Silverlight projects +Generated_Code/ + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm +ServiceFabricBackup/ +*.rptproj.bak + +# SQL Server files +*.mdf +*.ldf +*.ndf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings +*.rptproj.rsuser +*- [Bb]ackup.rdl +*- [Bb]ackup ([0-9]).rdl +*- [Bb]ackup ([0-9][0-9]).rdl + +# Microsoft Fakes +FakesAssemblies/ + +# GhostDoc plugin setting file +*.GhostDoc.xml + +# Node.js Tools for Visual Studio +.ntvs_analysis.dat +node_modules/ + +# Visual Studio 6 build log +*.plg + +# Visual Studio 6 workspace options file +*.opt + +# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) +*.vbw + +# Visual Studio 6 auto-generated project file (contains which files were open etc.) +*.vbp + +# Visual Studio 6 workspace and project file (working project files containing files to include in project) +*.dsw +*.dsp + +# Visual Studio 6 technical files +*.ncb +*.aps + +# Visual Studio LightSwitch build output +**/*.HTMLClient/GeneratedArtifacts +**/*.DesktopClient/GeneratedArtifacts +**/*.DesktopClient/ModelManifest.xml +**/*.Server/GeneratedArtifacts +**/*.Server/ModelManifest.xml +_Pvt_Extensions + +# Paket dependency manager +.paket/paket.exe +paket-files/ + +# FAKE - F# Make +.fake/ + +# CodeRush personal settings +.cr/personal + +# Python Tools for Visual Studio (PTVS) +__pycache__/ +*.pyc + +# Python Packages +build/ +*.egg-info/ +dist/ + +# Cake - Uncomment if you are using it +# tools/** +# !tools/packages.config + +# Tabs Studio +*.tss + +# Telerik's JustMock configuration file +*.jmconfig + +# BizTalk build output +*.btp.cs +*.btm.cs +*.odx.cs +*.xsd.cs + +# OpenCover UI analysis results +OpenCover/ + +# Azure Stream Analytics local run output +ASALocalRun/ + +# MSBuild Binary and Structured Log +*.binlog + +# NVidia Nsight GPU debugger configuration file +*.nvuser + +# MFractors (Xamarin productivity tool) working folder +.mfractor/ + +# Local History for Visual Studio +.localhistory/ + +# Visual Studio History (VSHistory) files +.vshistory/ + +# BeatPulse healthcheck temp database +healthchecksdb + +# Backup folder for Package Reference Convert tool in Visual Studio 2017 +MigrationBackup/ + +# Ionide (cross platform F# VS Code tools) working folder +.ionide/ + +# Fody - auto-generated XML schema +FodyWeavers.xsd + +# VS Code files for those working on multiple tools +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +*.code-workspace + +# Local History for Visual Studio Code +.history/ + +# Windows Installer files from build outputs +*.cab +*.msi +*.msix +*.msm +*.msp + +# JetBrains Rider +*.sln.iml + +# mkdocs +site/ + +# version +**/_version.py + +# pytest +test.json +test.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..bca1559 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,76 @@ +default_language_version: + python: python3 +repos: + - repo: https://github.com/PSF/black + rev: 24.3.0 + hooks: + - id: black + args: [--safe, --quiet] + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort + - repo: https://github.com/PyCQA/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + additional_dependencies: + - flake8-bugbear + - flake8-comprehensions + - flake8-simplify + - repo: https://github.com/asottile/pyupgrade + rev: v3.15.1 + hooks: + - id: pyupgrade + args: [--keep-runtime-typing] + - repo: https://github.com/tox-dev/pyproject-fmt + rev: 1.7.0 + hooks: + - id: pyproject-fmt + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.8.0 + hooks: + - id: mypy + files: chanfig + additional_dependencies: + - types-PyYaml + - repo: https://github.com/codespell-project/codespell + rev: v2.2.6 + hooks: + - id: codespell + args: [--ignore-words=.codespell-whitelist.txt] + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v4.0.0-alpha.8 + hooks: + - id: prettier + files: chanfig + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: check-added-large-files + - id: check-ast + - id: check-byte-order-marker + - id: check-builtin-literals + - id: check-case-conflict + - id: check-docstring-first + - id: check-merge-conflict + - id: check-vcs-permalinks + - id: check-symlinks + - id: pretty-format-json + files: chanfig + - id: check-json + - id: check-xml + - id: check-toml + - id: check-yaml + files: chanfig + - id: debug-statements + - id: end-of-file-fixer + files: chanfig + - id: fix-byte-order-marker + - id: fix-encoding-pragma + args: ["--remove"] + - id: mixed-line-ending + args: ["--fix=lf"] + - id: requirements-txt-fixer + - id: trailing-whitespace diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..521db02 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,98 @@ +[build-system] +build-backend = "setuptools.build_meta" +requires = [ + "setuptools", + "setuptools-scm", +] + +[project] +name = "rnabert" +description = "RNA BERT" +readme = "README.md" +keywords = [ + "deep-learning", + "machine-learning", + "RNA", +] +license = {file = "LICENSE"} +maintainers = [ + {name = "Zhiyuan Chen", email = "this@zyc.ai"}, +] +authors = [ + {name = "Zhiyuan Chen", email = "this@zyc.ai"}, +] +requires-python = ">=3.7" +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] +dynamic = [ + "version", +] + +[tool.setuptools] +packages = ["rnabert"] + +[tool.setuptools_scm] +write_to = "rnabert/_version.py" + +[tool.black] +line-length = 120 + +[tool.isort] +line_length = 120 +profile = "black" + +[tool.flake8] +max-line-length = 120 + +[tool.pylint.format] +max-line-length = 120 + +[tool.pylint.messages_control] +disable = """ + E0012, + E0401, + R0201, + R0801, +""" + +[tool.pylint.reports] +output-format = "colorized" + +[tool.pylint.main] +fail-under = 9.8 + +[tool.pytest.ini_options] +addopts = "--doctest-modules --cov" + +[tool.coverage.run] +branch = true +include = ["rnabert/**"] + +[tool.coverage.paths] +source = ["rnabert"] + +[tool.coverage.xml] +output = "coverage.xml" + +[tool.coverage.json] +output = "coverage.json" + +[tool.coverage.report] +show_missing = true +fail_under = 80 + +[tool.mypy] +ignore_missing_imports = true diff --git a/rnabert/__init__.py b/rnabert/__init__.py index 4762bd6..ac898ed 100644 --- a/rnabert/__init__.py +++ b/rnabert/__init__.py @@ -2,5 +2,4 @@ from .modeling_rnabert import RnaBertModel from .tokenization_rnabert import RnaBertTokenizer - __all__ = ["RnaBertConfig", "RnaBertModel", "RnaBertTokenizer"] diff --git a/rnabert/configuration_rnabert.py b/rnabert/configuration_rnabert.py index 1e342fa..52c6089 100644 --- a/rnabert/configuration_rnabert.py +++ b/rnabert/configuration_rnabert.py @@ -1,7 +1,6 @@ from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging - logger = logging.get_logger(__name__) @@ -138,4 +137,4 @@ def get_default_vocab_list(): "-", "", "", - ) \ No newline at end of file + ) diff --git a/rnabert/modeling_rnabert.py b/rnabert/modeling_rnabert.py index 8e2f6dd..1092681 100644 --- a/rnabert/modeling_rnabert.py +++ b/rnabert/modeling_rnabert.py @@ -24,15 +24,9 @@ def forward(self, x): class RnaBertEmbeddings(nn.Module): def __init__(self, config): super().__init__() - self.word_embeddings = nn.Embedding( - config.vocab_size, config.hidden_size, padding_idx=0 - ) - self.position_embeddings = nn.Embedding( - config.max_position_embeddings, config.hidden_size - ) - self.token_type_embeddings = nn.Embedding( - config.type_vocab_size, config.hidden_size - ) + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) self.LayerNorm = RnaBertLayerNorm(config.hidden_size, eps=1e-12) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -44,9 +38,7 @@ def forward(self, input_ids, token_type_ids=None): token_type_embeddings = self.token_type_embeddings(token_type_ids) seq_length = input_ids.size(1) - position_ids = torch.arange( - seq_length, dtype=torch.long, device=input_ids.device - ) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) position_ids = position_ids.unsqueeze(0).expand_as(input_ids) position_embeddings = self.position_embeddings(position_ids) @@ -69,18 +61,13 @@ def __init__(self, config): self.output = RnaBertOutput(config) def forward(self, hidden_states, attention_mask, attention_show_flg=False): - if attention_show_flg == True: - attention_output, attention_probs = self.attention( - hidden_states, attention_mask, attention_show_flg - ) + if attention_show_flg: + attention_output, attention_probs = self.attention(hidden_states, attention_mask, attention_show_flg) intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) return layer_output, attention_probs - - elif attention_show_flg == False: - attention_output = self.attention( - hidden_states, attention_mask, attention_show_flg - ) + else: + attention_output = self.attention(hidden_states, attention_mask, attention_show_flg) intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) return layer_output # [batch_size, seq_length, hidden_size] @@ -93,17 +80,12 @@ def __init__(self, config): self.output = RnaBertSelfOutput(config) def forward(self, input_tensor, attention_mask, attention_show_flg=False): - if attention_show_flg == True: - self_output, attention_probs = self.selfattn( - input_tensor, attention_mask, attention_show_flg - ) + if attention_show_flg: + self_output, attention_probs = self.selfattn(input_tensor, attention_mask, attention_show_flg) attention_output = self.output(self_output, input_tensor) return attention_output, attention_probs - - elif attention_show_flg == False: - self_output = self.selfattn( - input_tensor, attention_mask, attention_show_flg - ) + else: + self_output = self.selfattn(input_tensor, attention_mask, attention_show_flg) attention_output = self.output(self_output, input_tensor) return attention_output @@ -154,9 +136,9 @@ def forward(self, hidden_states, attention_mask, attention_show_flg=False): new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) - if attention_show_flg == True: + if attention_show_flg: return context_layer, attention_probs - elif attention_show_flg == False: + else: return context_layer @@ -213,9 +195,7 @@ def forward(self, hidden_states, input_tensor): class RnaBertEncoder(nn.Module): def __init__(self, config): super().__init__() - self.layer = nn.ModuleList( - [RnaBertLayer(config) for _ in range(config.num_hidden_layers)] - ) + self.layer = nn.ModuleList([RnaBertLayer(config) for _ in range(config.num_hidden_layers)]) # self.layer = nn.ModuleList([nn.Linear(config.hidden_size, config.hidden_size) # for _ in range(config.num_hidden_layers)]) @@ -227,24 +207,20 @@ def forward( attention_show_flg=False, ): all_encoder_layers = [] - for i, layer_module in enumerate(self.layer): - if attention_show_flg == True: - hidden_states, attention_probs = layer_module( - hidden_states, attention_mask, attention_show_flg - ) - elif attention_show_flg == False: - hidden_states = layer_module( - hidden_states, attention_mask, attention_show_flg - ) + for layer in self.layer: + if attention_show_flg: + hidden_states, attention_probs = layer(hidden_states, attention_mask, attention_show_flg) + else: + hidden_states = layer(hidden_states, attention_mask, attention_show_flg) if output_all_encoded_layers: all_encoder_layers.append(hidden_states) if not output_all_encoded_layers: all_encoder_layers.append(hidden_states) - if attention_show_flg == True: + if attention_show_flg: return all_encoder_layers, attention_probs - elif attention_show_flg == False: + else: return all_encoder_layers @@ -293,6 +269,7 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + class RnaBertModel(RnaBertPreTrainedModel): def __init__(self, config): @@ -321,15 +298,14 @@ def forward( embedding_output = self.embeddings(input_ids, token_type_ids) - if attention_show_flg == True: + if attention_show_flg: encoded_layers, attention_probs = self.encoder( embedding_output, extended_attention_mask, output_all_encoded_layers, attention_show_flg, ) - - elif attention_show_flg == False: + else: encoded_layers = self.encoder( embedding_output, extended_attention_mask, @@ -342,9 +318,9 @@ def forward( if not output_all_encoded_layers: encoded_layers = encoded_layers[-1] - if attention_show_flg == True: + if attention_show_flg: return encoded_layers, pooled_output, attention_probs - elif attention_show_flg == False: + else: return encoded_layers, pooled_output @@ -373,9 +349,7 @@ def __init__(self, config): self.transform = RnaBertPredictionHeadTransform(config) - self.decoder = nn.Linear( - in_features=config.hidden_size, out_features=config.vocab_size, bias=False - ) + self.decoder = nn.Linear(in_features=config.hidden_size, out_features=config.vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) def forward(self, hidden_states): @@ -425,25 +399,22 @@ def forward( attention_mask=None, attention_show_flg=False, ): - if attention_show_flg == False: - encoded_layers, pooled_output = self.bert( + if attention_show_flg: + encoded_layers, pooled_output, attention_probs = self.bert( input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, - attention_show_flg=False, + attention_show_flg=True, ) - else: - encoded_layers, pooled_output, attention_probs = self.bert( + encoded_layers, pooled_output = self.bert( input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, - attention_show_flg=True, + attention_show_flg=False, ) - prediction_scores, prediction_scores_ss, seq_relationship_score = self.cls( - encoded_layers, pooled_output - ) + prediction_scores, prediction_scores_ss, seq_relationship_score = self.cls(encoded_layers, pooled_output) return prediction_scores, prediction_scores_ss, encoded_layers diff --git a/rnabert/tokenization_rnabert.py b/rnabert/tokenization_rnabert.py index 2a2c2af..efcecff 100644 --- a/rnabert/tokenization_rnabert.py +++ b/rnabert/tokenization_rnabert.py @@ -4,7 +4,6 @@ from transformers.tokenization_utils import PreTrainedTokenizer from transformers.utils import logging - logger = logging.get_logger(__name__) VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} @@ -21,9 +20,9 @@ def load_vocab_file(vocab_file): - with open(vocab_file, "r") as f: + with open(vocab_file) as f: lines = f.read().splitlines() - return [l.strip() for l in lines] + return [l.strip() for l in lines] # noqa: E741 class RnaBertTokenizer(PreTrainedTokenizer): @@ -131,4 +130,4 @@ def save_vocabulary(self, save_directory, filename_prefix): @property def vocab_size(self) -> int: - return len(self.all_tokens) \ No newline at end of file + return len(self.all_tokens) diff --git a/setup.py b/setup.py new file mode 100755 index 0000000..6068493 --- /dev/null +++ b/setup.py @@ -0,0 +1,3 @@ +from setuptools import setup + +setup() diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..fb02b3f --- /dev/null +++ b/tox.ini @@ -0,0 +1,7 @@ +[flake8] +max-line-length = 120 + +[pycodestyle] +count = True +statistics = True +max-line-length = 120