Skip to content

Commit

Permalink
Bump mlflow max version (#1629)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Nov 1, 2024
1 parent dd77e86 commit 92252ce
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
2 changes: 2 additions & 0 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
)

import mlflow
+ import mlflow.environment_variables
mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set(
'1GB',
)
Expand Down Expand Up @@ -694,6 +695,7 @@ def tensor_hook(

# TODO: Remove after mlflow fixes the bug that makes this necessary
import mlflow
+ import mlflow.store
mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
model_saving_kwargs: dict[str, Any] = {
'path': local_save_path,
Expand Down
26 changes: 16 additions & 10 deletions llmfoundry/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
import mlflow
from composer.loggers import Logger
from composer.utils import dist, parse_uri
+ from mlflow.data import (
+ delta_dataset_source,
+ http_dataset_source,
+ huggingface_dataset_source,
+ uc_volume_dataset_source,
+ )
from omegaconf import MISSING, DictConfig, ListConfig, MissingMandatoryValue
from omegaconf import OmegaConf as om
from transformers import PretrainedConfig
Expand Down Expand Up @@ -769,15 +775,15 @@ def log_dataset_uri(cfg: dict[str, Any]) -> None:
data_paths = _parse_source_dataset(cfg)

dataset_source_mapping = {
- 's3': mlflow.data.http_dataset_source.HTTPDatasetSource,
- 'oci': mlflow.data.http_dataset_source.HTTPDatasetSource,
- 'azure': mlflow.data.http_dataset_source.HTTPDatasetSource,
- 'gs': mlflow.data.http_dataset_source.HTTPDatasetSource,
- 'https': mlflow.data.http_dataset_source.HTTPDatasetSource,
- 'hf': mlflow.data.huggingface_dataset_source.HuggingFaceDatasetSource,
- 'delta_table': mlflow.data.delta_dataset_source.DeltaDatasetSource,
- 'uc_volume': mlflow.data.uc_volume_dataset_source.UCVolumeDatasetSource,
- 'local': mlflow.data.http_dataset_source.HTTPDatasetSource,
+ 's3': http_dataset_source.HTTPDatasetSource,
+ 'oci': http_dataset_source.HTTPDatasetSource,
+ 'azure': http_dataset_source.HTTPDatasetSource,
+ 'gs': http_dataset_source.HTTPDatasetSource,
+ 'https': http_dataset_source.HTTPDatasetSource,
+ 'hf': huggingface_dataset_source.HuggingFaceDatasetSource,
+ 'delta_table': delta_dataset_source.DeltaDatasetSource,
+ 'uc_volume': uc_volume_dataset_source.UCVolumeDatasetSource,
+ 'local': http_dataset_source.HTTPDatasetSource,
}

# Map data source types to their respective MLFlow DataSource.
Expand All @@ -795,7 +801,7 @@ def log_dataset_uri(cfg: dict[str, Any]) -> None:
log.info(
f'{dataset_type} unknown, defaulting to http dataset source',
)
- source = mlflow.data.http_dataset_source.HTTPDatasetSource(url=path)
+ source = http_dataset_source.HTTPDatasetSource(url=path)

mlflow.log_input(
mlflow.data.meta_dataset.MetaDataset(source, name=split),
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@

install_requires = [
'mosaicml[libcloud,wandb,oci,gcs,mlflow]>=0.26.0,<0.27',
- 'mlflow>=2.14.1,<2.17',
+ 'mlflow>=2.14.1,<2.18',
'accelerate>=0.25,<0.34', # for HF inference `device_map`
'transformers>=4.43.2,<4.44',
'mosaicml-streaming>=0.9.0,<0.10',
Expand Down

0 comments on commit 92252ce

Please sign in to comment.