diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..4f5754d --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,25 @@ +# .readthedocs.yml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +build: + os: "ubuntu-20.04" + tools: + python: "3.9" + +# Build documentation in the docs/ directory with Sphinx +sphinx: + configuration: docs/conf.py + fail_on_warning: true + +# Optionally build your docs in additional formats such as PDF and ePub +formats: + - htmlzip + +# Optionally set the version of Python and requirements required to build your docs +python: + install: + - requirements: docs/requirements.txt \ No newline at end of file diff --git a/docs/CheckPointing.md b/docs/CheckPointing.md deleted file mode 100644 index e04e4fe..0000000 --- a/docs/CheckPointing.md +++ /dev/null @@ -1,55 +0,0 @@ -* **StreamingCheckpointer** class: - * **__init__(config, checkpoint_dir, enable=True)**: - Initializes a StreamingCheckpointer object. - Args: - config: A dictionary of configuration options. - checkpoint_dir: The directory where checkpoints will be saved. - enable: Whether to enable the streaming checkpointing functionality. - * **save_checkpoint(train_state, filename, gather_fns=None)**: - Saves a checkpoint to the specified file. - Args: - train_state: The train state to save. - filename: The name of the checkpoint file. - gather_fns: A dictionary of functions that can be used to gather - sharded tensors into full arrays before saving them. - * **save_all(train_state, gather_fns, metadata=None, dataset=None, milestone=False)**: - Saves a checkpoint for the current step, as well as metadata and dataset - information. - Args: - train_state: The train state to save. - gather_fns: A dictionary of functions that can be used to gather - sharded tensors into full arrays before saving them. - metadata: Metadata to save. - dataset: Dataset information to save. - milestone: Whether this is a milestone checkpoint. - * **load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None)**: - Loads a checkpoint from the specified file. - Args: - path: The path to the checkpoint file. - target: The object to load the checkpoint into. - shard_fns: A dictionary of functions that can be used to shard - tensors after loading them. - remove_dict_prefix: A tuple key prefix to strip from the keys of the loaded - checkpoint. - * **load_flax_checkpoint(path, target=None, shard_fns=None)**: - Loads a standard flax checkpoint from the specified file. - Args: - path: The path to the checkpoint file. - target: The object to load the checkpoint into. - shard_fns: A dictionary of functions that can be used to shard - tensors after loading them. - -* **load_trainstate_checkpoint(load_from, trainstate_target=None, - trainstate_shard_fns=None, - disallow_trainstate=False)**: - Load a train state checkpoint from the specified load_from string. - Args: - load_from: The load_from string, which can be one of the following: - * 'trainstate': Load the entire train state. - * 'trainstate_params': Load the params part of the train state. - * 'params': Load the params. - * 'flax_params': Load the params in the standard flax format (non-streaming). - trainstate_target: The target object to load the train state into. - trainstate_shard_fns: A dictionary of functions that can be used to shard - tensors after loading them. - disallow_trainstate: Whether to disallow loading the full train state.
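For orientation, here is a minimal sketch of how the StreamingCheckpointer methods documented above fit together. The import path, the `checkpointer_config` object, and the `train_state`/`gather_fns`/`shard_fns` values are assumptions standing in for whatever your training setup provides; this is not taken from the fjformer sources.

```python
# Illustrative sketch only -- the import path and the surrounding objects
# (checkpointer_config, train_state, gather_fns, shard_fns) are assumed here.
from fjformer import StreamingCheckpointer  # assumed export location

checkpointer = StreamingCheckpointer(
    config=checkpointer_config,    # checkpointer configuration options from your setup
    checkpoint_dir='checkpoints',  # directory where checkpoint files are written
    enable=True,                   # e.g. enable only on the primary host
)

# Save the train state (plus metadata) for the current step.
checkpointer.save_all(
    train_state,
    gather_fns,                    # gathers sharded tensors before writing
    metadata={'step': 1000},
    milestone=False,               # True marks a checkpoint to keep permanently
)

# Reload a single checkpoint file, re-sharding tensors as they are read.
restored = checkpointer.load_checkpoint(
    'checkpoints/streaming_train_state_1000',  # hypothetical file name
    target=train_state,
    shard_fns=shard_fns,
)
```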
diff --git a/docs/Optimizers.md b/docs/Optimizers.md deleted file mode 100644 index a93bfa1..0000000 --- a/docs/Optimizers.md +++ /dev/null @@ -1,205 +0,0 @@ -* **optax_add_scheduled_weight_decay** function: - * **Arguments:** - * **schedule_fn:** A function that takes the current step number as input and returns the weight decay value. - * **mask:** An optional mask that can be used to apply weight decay to a subset of parameters. - * **Returns:** - An Optax GradientTransformation object that adds the scheduled weight decay to the updates. - -* **get_adamw_with_cosine_scheduler** function: - -```python -tx, scheduler = get_adamw_with_cosine_scheduler(*args, **kwargs) -``` - -* **Arguments:** - * **steps:** The total number of training steps. - * **learning_rate:** The initial learning rate. - * **b1:** The first Adam beta parameter. - * **b2:** The second Adam beta parameter. - * **eps:** The Adam epsilon parameter. - * **eps_root:** The Adam epsilon root parameter. - * **weight_decay:** The weight decay coefficient. - * **gradient_accumulation_steps:** The number of gradient accumulation steps. - * **mu_dtype:** The dtype of the Adam momentum terms. -* **Returns:** - A tuple of the Adam optimizer and the cosine learning rate scheduler. - -* **get_adamw_with_linear_scheduler** function: - -```python -tx, scheduler = get_adamw_with_linear_scheduler(*args, **kwargs) -``` - -* **Arguments:** - * **steps:** The total number of training steps. - * **learning_rate_start:** The initial learning rate. - * **learning_rate_end:** The final learning rate. - * **b1:** The first Adam beta parameter. - * **b2:** The second Adam beta parameter. - * **eps:** The Adam epsilon parameter. - * **eps_root:** The Adam epsilon root parameter. - * **weight_decay:** The weight decay coefficient. - * **gradient_accumulation_steps:** The number of gradient accumulation steps. - * **mu_dtype:** The dtype of the Adam momentum terms. -* **Returns:** - A tuple of the Adam optimizer and the linear learning rate scheduler. - -* **get_adafactor_with_linear_scheduler** function: - -```python -tx, scheduler = get_adafactor_with_linear_scheduler(*args, **kwargs) -``` - -* **Arguments:** - * **steps:** The total number of training steps. - * **learning_rate_start:** The initial learning rate. - * **learning_rate_end:** The final learning rate. - * **weight_decay:** The weight decay coefficient. - * **min_dim_size_to_factor:** The minimum size of a parameter tensor for it to be factored by Adafactor. - * **decay_rate:** The decay rate parameter for Adafactor. - * **decay_offset:** The decay offset parameter for Adafactor. - * **multiply_by_parameter_scale:** Whether to multiply the learning rate by the parameter scale. - * **clipping_threshold:** The gradient clipping threshold. - * **momentum:** The momentum parameter for Adafactor. - * **dtype_momentum:** The dtype of the momentum term for Adafactor. - * **weight_decay_rate:** The weight decay rate for Adafactor. - * **eps:** The epsilon parameter for Adafactor. - * **factored:** Whether to use the factored implementation of Adafactor. - * **gradient_accumulation_steps:** The number of gradient accumulation steps. - * **weight_decay_mask:** A mask tensor that specifies which parameters to apply weight decay to. -* **Returns:** - A tuple of the Adafactor optimizer and the linear learning rate scheduler. 
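As a quick illustration of the calling pattern shared by these factory functions, the sketch below builds the Adafactor optimizer with a linear schedule. The import path is an assumption, the omitted keyword arguments are assumed to have defaults, and the returned scheduler is assumed to be an optax-style step-to-learning-rate callable.

```python
# Illustrative sketch -- import path and keyword defaults are assumed.
from fjformer.optimizers import get_adafactor_with_linear_scheduler  # assumed location

total_steps = 10_000

tx, scheduler = get_adafactor_with_linear_scheduler(
    steps=total_steps,
    learning_rate_start=1e-3,
    learning_rate_end=1e-5,
    gradient_accumulation_steps=1,
)

# `tx` is an ordinary optax GradientTransformation, so it can be passed to
# flax.training.train_state.TrainState.create(..., tx=tx).
# `scheduler(step)` returns the learning rate at a given step, handy for logging.
```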
- -* **get_adafactor_with_cosine_scheduler** function: - -```python -tx, scheduler = get_adafactor_with_cosine_scheduler(*args, **kwargs) -``` - -* **Arguments:** - * **steps:** The total number of training steps. - * **learning_rate:** The initial learning rate. - * **weight_decay:** The weight decay coefficient. - * **min_dim_size_to_factor:** The minimum size of a parameter tensor for it to be factored by Adafactor. - * **decay_rate:** The decay rate parameter for Adafactor. - * **decay_offset:** The decay offset parameter for Adafactor. - * **multiply_by_parameter_scale:** Whether to multiply the learning rate by the parameter scale. - * **clipping_threshold:** The gradient clipping threshold. - * **momentum:** The momentum parameter for Adafactor. - * **dtype_momentum:** The dtype of the momentum term for Adafactor. - * **weight_decay_rate:** The weight decay rate for Adafactor. - * **eps:** The epsilon parameter for Adafactor. - * **factored:** Whether to use the factored implementation of Adafactor. - * **gradient_accumulation_steps:** The number of gradient accumulation steps. - * **weight_decay_mask:** A mask tensor that specifies which parameters to apply weight decay to. -* **Returns:** - A tuple of the Adafactor optimizer and the cosine learning rate scheduler. - -* **get_lion_with_cosine_scheduler** function: - -```python -tx, scheduler = get_lion_with_cosine_scheduler(*args, **kwargs) -``` - -* **Arguments:** - * **steps:** The total number of training steps. - * **learning_rate:** The initial learning rate. - * **alpha:** The minimum value of the multiplier used to adjust the learning rate. - * **exponent:** The exponent of the cosine decay schedule. - * **b1:** The first Lion beta parameter. - * **b2:** The second Lion beta parameter. - * **gradient_accumulation_steps:** The number of gradient accumulation steps. - * **mu_dtype:** The dtype of the Lion momentum terms. -* **Returns:** - A tuple of the Lion optimizer and the cosine learning rate scheduler. - -* **get_lion_with_linear_scheduler** function: - -```python -tx, scheduler = get_lion_with_linear_scheduler(*args, **kwargs) -``` - -* **Arguments:** - * **steps:** The total number of training steps. - * **learning_rate_start:** The initial learning rate. - * **learning_rate_end:** The final learning rate. - * **b1:** The first Lion beta parameter. - * **b2:** The second Lion beta parameter. - * **gradient_accumulation_steps:** The number of gradient accumulation steps. - * **mu_dtype:** The dtype of the Lion momentum terms. -* **Returns:** - A tuple of the Lion optimizer and the linear learning rate scheduler. - -* **get_adamw_with_warm_up_cosine_scheduler** function: - -```python -tx, scheduler = get_adamw_with_warm_up_cosine_scheduler(*args, **kwargs) -``` - -* **Arguments:** - * **steps:** The total number of training steps. - * **learning_rate:** The peak learning rate of the warm-up cosine schedule. - * **b1:** The first Adam beta parameter. - * **b2:** The second Adam beta parameter. - * **eps:** The Adam epsilon parameter. - * **eps_root:** The Adam epsilon root parameter. - * **weight_decay:** The weight decay coefficient. - * **exponent:** The exponent of the cosine decay schedule. - * **gradient_accumulation_steps:** The number of gradient accumulation steps. - * **mu_dtype:** The dtype of the Adam momentum terms. -* **Returns:** - A tuple of the AdamW optimizer and the warm-up cosine learning rate scheduler.
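The Lion variants above follow the same calling convention. The short sketch below builds the Lion optimizer with a cosine schedule and prints the learning rate at a few steps; the import path, the availability of defaults for the omitted keyword arguments, and the scheduler being a plain step-to-learning-rate callable are assumptions.

```python
# Illustrative sketch -- import path and keyword defaults are assumed.
from fjformer.optimizers import get_lion_with_cosine_scheduler  # assumed location

tx, scheduler = get_lion_with_cosine_scheduler(
    steps=10_000,
    learning_rate=1e-4,
    gradient_accumulation_steps=1,
)

# Inspect the cosine decay by sampling the scheduler at a few step values.
for step in (0, 2_500, 5_000, 10_000):
    print(step, float(scheduler(step)))
```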
- -* **get_adafactor_with_warm_up_cosine_scheduler** function: - -```python -tx, scheduler = get_adafactor_with_warm_up_cosine_scheduler(*args, **kwargs) -``` - -* **Arguments:** - * **steps:** The total number of training steps. - * **learning_rate:** The initial learning rate. - * **weight_decay:** The weight decay coefficient. - * **min_dim_size_to_factor:** The minimum size of a parameter tensor for it to be factored by Adafactor. - * **decay_rate:** The decay rate parameter for Adafactor. - * **decay_offset:** The decay offset parameter for Adafactor. - * **multiply_by_parameter_scale:** Whether to multiply the learning rate by the parameter scale. - * **clipping_threshold:** The gradient clipping threshold. - * **momentum:** The momentum parameter for Adafactor. - * **dtype_momentum:** The dtype of the momentum term for Adafactor. - * **weight_decay_rate:** The weight decay rate for Adafactor. - * **eps:** The epsilon parameter for Adafactor. - * **factored:** Whether to use the factored implementation of Adafactor. - * **exponent:** The exponent of the cosine decay schedule. - * **weight_decay_mask:** A mask tensor that specifies which parameters to apply weight decay to. - * **gradient_accumulation_steps:** The number of gradient accumulation steps. -* **Returns:** - A tuple of the Adafactor optimizer and the cosine learning rate scheduler. - -* **get_lion_with_warm_up_cosine_scheduler** function: - -```python -tx, scheduler = get_lion_with_warm_up_cosine_scheduler(*args, **kwargs) -``` - -* **Arguments:** - * **steps:** The total number of training steps. - * **learning_rate:** The initial learning rate. - * **exponent:** The exponent of the cosine decay schedule. - * **b1:** The first Lion beta parameter. - * **b2:** The second Lion beta parameter. - * **gradient_accumulation_steps:** The number of gradient accumulation steps. - * **mu_dtype:** The dtype of the Lion momentum terms. -* **Returns:** - A tuple of the Lion optimizer and the cosine learning rate scheduler. - -The references for these functions are: - -* Lion (Symbolic Discovery of Optimization Algorithms): https://arxiv.org/abs/2302.06675 -* Adam: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 -* SGDR: Stochastic Gradient Descent with Warm Restarts (cosine annealing): https://arxiv.org/abs/1608.03983 -* Adafactor: Adaptive Learning Rates with Sublinear Memory Cost: https://arxiv.org/abs/1804.04235 \ No newline at end of file diff --git a/docs/README.md b/docs/README.md deleted file mode 100644 index f1cb62a..0000000 --- a/docs/README.md +++ /dev/null @@ -1,5 +0,0 @@ -## Here are some utilities you might find useful - -1. [Optimizer](https://erfanzar.github.io/fjformer/docs/Optimizers) -2. [Check Pointing Utils](https://erfanzar.github.io/fjformer/docs/CheckPointing) -3. [Utils](https://erfanzar.github.io/fjformer/docs/Utils) \ No newline at end of file diff --git a/docs/Utils.md b/docs/Utils.md deleted file mode 100644 index 9959cc1..0000000 --- a/docs/Utils.md +++ /dev/null @@ -1,280 +0,0 @@ -* **is_torch_available()** function: - * **Returns:** - A boolean value indicating whether the PyTorch library is installed. - -* **match_partition_rules(rules, params)** function: - * **Arguments:** - * **rules:** A list of tuples, where each tuple consists of a regular expression and a sharding spec. - * **params:** A flax.core.FrozenDict of model parameters.
- * **Returns:** - A pytree with the same structure as `params` whose leaves are the sharding specs of the first matching rule. - -* **count_num_params(_p)** function: - * **Arguments:** - * **_p:** A flax.core.FrozenDict of model parameters. - * **Returns:** - The number of parameters in the model. - -* **count_params(_p)** function: - * **Arguments:** - * **_p:** A flax.core.FrozenDict of model parameters. - * **Returns:** - The number of parameters in the model in billions. - -* **names_in_mesh(*names)** function: - * **Arguments:** - * **names:** One or more strings (mesh axis names). - * **Returns:** - A boolean value indicating whether all of the given names are axis names of the current device mesh. - -* **get_names(partition_specs)** function: - * **Arguments:** - * **partition_specs:** A list of sharding specs. - * **Returns:** - A list of the axis names that appear in the given sharding specs. - -* **with_sharding_constraint__a(x, partition_spec)** function: - * **Arguments:** - * **x:** A Jax array. - * **partition_spec:** A sharding spec. - * **Returns:** - A Jax array with the same data as `x`, but with the sharding spec set to `partition_spec`. - -* **get_devices(tensor)** function: - * **Arguments:** - * **tensor:** A Jax array. - * **Returns:** - A list of strings, where each string is the device name of a device in the array. - - -* **change_to_bf16(tensor)** function: - * **Arguments:** - * **tensor:** A Jax array. - * **Returns:** - A Jax array with the same data as `tensor`, but with the dtype set to `jnp.bfloat16`. - -* **change_to_fp16(tensor)** function: - * **Arguments:** - * **tensor:** A Jax array. - * **Returns:** - A Jax array with the same data as `tensor`, but with the dtype set to `jnp.float16`. - -* **change_to_fp32(tensor)** function: - * **Arguments:** - * **tensor:** A Jax array. - * **Returns:** - A Jax array with the same data as `tensor`, but with the dtype set to `jnp.float32`. - -* **change(tensor, device)** function: - * **Arguments:** - * **tensor:** A Jax array. - * **device:** A string, the name of the device to put the tensor on. - * **Returns:** - A Jax array with the same data as `tensor`, but on the specified device. - -* **read_ckpt(path: [str, os.PathLike], shard_fns=None, add_extra_past_fix: list = None)** function: - * **Arguments:** - * **path:** The path to the checkpoint file. - * **shard_fns:** A dictionary mapping tensor names to functions that shard the corresponding tensors. - * **add_extra_past_fix:** A list of strings that should be prepended to all tensor names in the checkpoint file. - * **Returns:** - A dictionary of tensors, where the keys are the tensor names and the values are the tensors from the checkpoint - file. - -* **save_ckpt(train_state, path, gather_fns=None, float_dtype=None)** function: - * **Arguments:** - * **train_state:** The model state to save. - * **path:** The path to the checkpoint file. - * **gather_fns:** A dictionary mapping tensor names to functions that gather the corresponding tensors. - * **float_dtype:** The floating point dtype to use for the tensors in the checkpoint file. - * **Returns:** - None. - - -* **match_keywords(string, ts, ns)** function: - * **Arguments:** - * **string:** A string. - * **ts:** A list of strings, the keywords that should be present in the string. - * **ns:** A list of strings, the keywords that should not be present in the string. - * **Returns:** - A boolean value indicating whether the string contains all of the keywords in `ts` and none of the keywords - in `ns`.
- -* **load_and_convert_checkpoint(path, dtype=jnp.float16, transpose_needed: List[str] = ["kernel"], - transpose_not_needed: List[str] = ['none'], select_params_field: bool = True)** function: - * **Arguments:** - * **path:** The path to the checkpoint file. - * **dtype:** The floating point dtype to use for the tensors in the checkpoint file. - * **transpose_needed:** A list of strings, the names of the tensors that need to be transposed. - * **transpose_not_needed:** A list of strings, the names of the tensors that do not need to be transposed. - * **select_params_field:** A boolean value indicating whether to only load the `params` field from the - checkpoint file. - * **Returns:** - A dictionary of tensors, where the keys are the tensor names and the values are the tensors from the checkpoint - file, converted to the specified dtype and transposed if necessary. - -* **read_json(path)** function: - * **Arguments:** - * **path:** The path to the JSON file. - * **Returns:** - The contents of the JSON file as a dictionary. - -* **write_json(text, path)** function: - * **Arguments:** - * **text:** The object to serialize and write to the JSON file. - * **path:** The path to the JSON file. - * **Returns:** - None. - - -* **get_dataloader** function: - -```python -def get_dataloader(dataset_or_huggingface_dataset_hub_id: Any, batch_size: int, num_epochs: int, - select_hf_dataset_field='train', - max_steps: int = None, max_length: int = 4096, dataset_hf_kwargs: dict = {}, - collate_fn: Callable = None, shuffle: Optional[bool] = None, - sampler=None, - batch_sampler=None, - num_workers: int = 0, - pin_memory: bool = False, drop_last: bool = False, - timeout: float = 0, worker_init_fn=None, - multiprocessing_context=None, generator=None, - *, prefetch_factor: Optional[int] = None, - persistent_workers: bool = False, - pin_memory_device: str = ""): -``` - -**Documentation reference:** - -* `dataset_or_huggingface_dataset_hub_id`: The dataset to load. This can be either a string, which is the name of a - dataset from the Huggingface Datasets Hub, or a custom dataset object. -* `batch_size`: The batch size. -* `num_epochs`: The number of epochs to train for. -* `select_hf_dataset_field`: The field of the Huggingface Dataset to use, such as `train` or `validation`. -* `max_steps`: The maximum number of steps to train for. If `None`, the training will run for `num_epochs` * len(dataloader). -* `max_length`: The maximum length of a sequence in the dataloader. -* `dataset_hf_kwargs`: Keyword arguments to pass to the Huggingface Dataset loader. -* `collate_fn`: A function to collate the data into batches. If `None`, a default collate function will be used. -* `shuffle`: Whether to shuffle the data. -* `sampler`: A sampler to use for selecting data batches. -* `batch_sampler`: A batch sampler to use for selecting data batches. -* `num_workers`: The number of worker processes to use for data loading. -* `pin_memory`: Whether to copy batches into pinned (page-locked) host memory, which speeds up host-to-device transfers. -* `drop_last`: Whether to drop the last batch if it is not full. -* `timeout`: The timeout (in seconds) for collecting a batch from the workers. -* `worker_init_fn`: A function to be called on each worker process. -* `multiprocessing_context`: The multiprocessing context to use. -* `generator`: A random-number generator used to control sampling, as in `torch.utils.data.DataLoader`. -* `prefetch_factor`: The number of batches loaded in advance by each worker. -* `persistent_workers`: Whether to keep the worker processes alive after the dataloader is exhausted. -* `pin_memory_device`: The device on which to pin memory when `pin_memory` is enabled.
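Since `get_dataloader` has no usage example below, here is a minimal sketch of the intended call pattern. The dataset id and keyword choices are assumptions for illustration only, as is the assumption that a torch-style iterable dataloader is returned.

```python
# Illustrative sketch -- dataset id, keyword choices, and the returned
# torch-style dataloader are assumptions, not taken from the fjformer sources.
dataloader = get_dataloader(
    'tatsu-lab/alpaca',              # hypothetical Hugging Face dataset hub id
    batch_size=8,
    num_epochs=2,
    select_hf_dataset_field='train',
    max_length=2048,
    shuffle=True,
    num_workers=0,
)

for step, batch in enumerate(dataloader):
    ...  # feed `batch` to your train step
```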
- -Here is a usage example of the `match_partition_rules()` function: - -```python -import flax -from jax.sharding import PartitionSpec as PS - -# Partitioning rules: (regex over the parameter path, sharding spec to apply). -rules = [('.*embedding.*', PS()), ('.*kernel.*', PS())] - -# `params` is the parameter pytree of a Flax model (e.g. `train_state.params`). -params = flax.core.unfreeze(train_state.params) - -# Resolve a sharding spec for every parameter in the model. -partition_specs = match_partition_rules(rules, params) -``` - -The `count_num_params()` and `count_params()` functions can be used to count the number of parameters in a model. For -example: - -```python -num_params = count_num_params(params) -print('The model has {} parameters.'.format(num_params)) -``` - -The `names_in_mesh()` function can be used to check whether a set of axis names belong to the current device mesh. For example: - -```python -names = ['dp', 'mp'] -if names_in_mesh(*names): - print('All of the names are axes of the current device mesh.') -else: - print('Some of the names are not axes of the current device mesh.') -``` - -Here is a usage example of the `match_keywords()` function: - -```python -# Define a string. -string = 'kernel' - -# The string contains the keyword `kernel` and does not contain `bias`. -assert match_keywords(string, ['kernel'], ['bias']) == True - -# The string does not contain the required keyword `bias`. -assert match_keywords(string, ['bias'], ['kernel']) == False -``` - -The `load_and_convert_checkpoint()` function can be used to load a Flax checkpoint file and convert it to a Torch -checkpoint file. For example: - -```python -import numpy as np -import torch - -# Define the path to the checkpoint file. -path = 'checkpoint.ckpt' - -# Load the Flax checkpoint and convert/transpose its tensors. -flax_params = load_and_convert_checkpoint(path) - -# Convert the Flax parameters to Torch tensors. -torch_params = {} -for key, tensor in flax_params.items(): - torch_params[key] = torch.from_numpy(np.asarray(tensor)) - -# Save the Torch parameters to a file. -torch.save(torch_params, 'torch_checkpoint.pth') -``` - -Here is a usage example of the `change_to_bf16()` function: - -```python -import jax.numpy as jnp - -# Define a float32 array. -array = jnp.ones((10, 10), dtype=jnp.float32) - -# Convert the array to bfloat16. -bfloat16_array = change_to_bf16(array) - -# Check the dtype of the array. -assert bfloat16_array.dtype == jnp.bfloat16 -``` - -The `read_ckpt()` and `save_ckpt()` functions can be used to load and save model checkpoints. For example: - -```python -# Define the path to the checkpoint file. -path = 'checkpoint.ckpt' - -# Load the model state from the checkpoint file. -train_state = read_ckpt(path) - -# Save the model state to the checkpoint file.
-save_ckpt(train_state, path) -``` - -The references for these functions are: - -* Flax: https://flax.readthedocs.io/en/latest/ -* JAX: https://jax.readthedocs.io/en/latest/ -* Torch: https://pytorch.org/ -* XLA: https://www.tensorflow.org/xla -* msgpack: https://msgpack.org/ \ No newline at end of file diff --git a/docs/_static/style.css b/docs/_static/style.css new file mode 100644 index 0000000..4e33f04 --- /dev/null +++ b/docs/_static/style.css @@ -0,0 +1,29 @@ +@import url("theme.css"); + +:root { + --block-bg-opacity: .5; +} + +.wy-side-nav-search { + background-color: #fff; +} + +.getting-started { + background-color: rgba(78, 150, 253, var(--block-bg-opacity)); +} + +.user-guides { + background-color: rgba(0, 169, 154, var(--block-bg-opacity)); +} + +.developer-docs { + background-color: rgba(171, 0, 182, var(--block-bg-opacity)); +} + +div.red-background pre { + background-color: rgba(244, 204, 204, var(--block-bg-opacity)); +} + +div.green-background pre { + background-color: rgba(204, 244, 204, var(--block-bg-opacity)); +} \ No newline at end of file diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html new file mode 100644 index 0000000..93a7941 --- /dev/null +++ b/docs/_templates/layout.html @@ -0,0 +1,2 @@ +{% extends "!layout.html" %} +{% set css_files = css_files + ["_static/style.css"] %} \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..46e6825 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,170 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# -*- coding: utf-8 -*- +# +# Configuration file for the Sphinx documentation builder. +# +# This file does only contain a selection of the most common options. For a +# full list see the documentation: +# http://www.sphinx-doc.org/en/master/config + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys + +sys.path.insert(0, os.path.abspath('..')) + +project = 'FJFormer' +copyright = '2023, The FJFormer. NumPy, JAX and SciPy documentation are copyright the respective authors.' +author = 'erfan zare chavoshi' + +# The short X.Y version +version = '' +# The full version, including alpha/beta/rc tags +release = '' + +# -- General configuration --------------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +# +needs_sphinx = '2.1' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. 
+sys.path.append(os.path.abspath('sphinxext')) +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.autosummary', + 'sphinx.ext.intersphinx', + 'sphinx.ext.mathjax', + 'sphinx.ext.napoleon', + 'sphinx.ext.viewcode', + 'matplotlib.sphinxext.plot_directive', + 'sphinx_autodoc_typehints', + 'myst_nb', + "sphinx_remove_toctrees", + 'sphinx_copybutton', + 'jax_extensions', + 'sphinx_design' +] + +intersphinx_mapping = { + 'python': ('https://docs.python.org/3/', None), + 'numpy': ('https://numpy.org/doc/stable/', None), + 'scipy': ('https://docs.scipy.org/doc/scipy/reference/', None), +} + +suppress_warnings = [ + 'ref.citation', # Many duplicated citations in numpy/scipy docstrings. + 'ref.footnote', # Many unreferenced footnotes in numpy/scipy docstrings + 'myst.header', + # TODO(jakevdp): remove this suppression once issue is fixed. + 'misc.highlighting_failure', # https://github.com/ipython/ipython/issues/14142 +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# Note: important to list ipynb before md here: we have both md and ipynb +# copies of each notebook, and myst will choose which to convert based on +# the order in the source_suffix list. Notebooks which are not executed have +# outputs stored in ipynb but not in md, so we must convert the ipynb. +source_suffix = ['.rst', '.ipynb', '.md'] + +# The main toctree document. +main_doc = 'index' + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = 'en' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = [ + '*.md' +] + +# The name of the Pygments (syntax highlighting) style to use. 
+pygments_style = None + +autosummary_generate = True +napoleon_use_rtype = False + +html_theme = 'sphinx_book_theme' + +html_theme_options = { + 'show_toc_level': 2, + 'repository_url': 'https://github.com/erfanzar/fjformer', + 'use_repository_button': True, # add a "link to repository" button + 'navigation_with_keys': False, +} + +html_logo = 'light-logo.png' +html_static_path = ['_static'] + +html_css_files = [ + 'style.css', +] + +myst_heading_anchors = 3 # auto-generate 3 levels of heading anchors +myst_enable_extensions = ['dollarmath'] +nb_execution_mode = "force" +nb_execution_allow_errors = False +nb_merge_streams = True + +nb_execution_timeout = 100 + +htmlhelp_basename = 'FJFormerdoc' + +latex_elements = { +} + +latex_documents = [ + (main_doc, 'FJFormer.tex', 'FJFormer Documentation', + author, 'manual'), +] + +man_pages = [ + (main_doc, 'FJFormer', 'FJFormer Documentation', + [author], 1) +] + +texinfo_documents = [ + (main_doc, 'FJFormer', 'FJFormer Documentation', + author, 'FJFormer', 'One line description of project.', + 'Miscellaneous'), +] +epub_title = project +epub_exclude_files = ['search.html'] + +always_document_param_types = True + +autodoc_type_aliases = { + 'ArrayLike': 'ArrayLike', + 'DTypeLike': 'DTypeLike', +} + +remove_from_toctrees = ["_autosummary/*"] diff --git a/docs/dark-logo.png b/docs/dark-logo.png new file mode 100644 index 0000000..b682043 Binary files /dev/null and b/docs/dark-logo.png differ diff --git a/docs/light-logo.png b/docs/light-logo.png new file mode 100644 index 0000000..e3aadeb Binary files /dev/null and b/docs/light-logo.png differ
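With `.readthedocs.yaml` and `docs/conf.py` in place, the HTML docs can also be built locally through Sphinx's Python entry point, which mirrors what Read the Docs runs in CI. This assumes the packages listed in `docs/requirements.txt` are already installed.

```python
# Build the HTML docs locally with the same configuration Read the Docs uses.
# Equivalent to running: sphinx-build -b html docs docs/_build/html
from sphinx.cmd.build import build_main

exit_code = build_main(['-b', 'html', 'docs', 'docs/_build/html'])
print('sphinx-build finished with exit code', exit_code)
```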