diff --git a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
index 7b17ccf43c..7abc28d9fe 100644
--- a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
+++ b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
@@ -43,9 +43,7 @@ def __init__(self, om_model_config: DictConfig, tokenizer: AutoTokenizer, api_ke
                 conda_package='openai',
                 conda_channel='conda-forge') from e
         if api_key is None:
-            api_key = os.environ.get('OPENAI_API_KEY')
-
-            api_key = os.environ.get('OPENAI_API_KEY')
+            api_key = os.environ.get(om_model_config.get('api_env_key', 'OPENAI_API_KEY'))
         base_url = om_model_config.get('base_url')
         if base_url is None:
             # Using OpenAI default, where the API key is required
@@ -53,12 +51,6 @@ def __init__(self, om_model_config: DictConfig, tokenizer: AutoTokenizer, api_ke
                 raise ValueError(
                     'No OpenAI API Key found. Ensure it is saved as an environmental variable called OPENAI_API_KEY.'
                 )
-        else:
-            # Using a custom base URL, where the API key may not be required
-            log.info(
-                f'Making request to custom base URL: {base_url}{"" if api_key is not None else " (no API key set)"}'
-            )
-            api_key = 'placeholder'  # This cannot be None
         self.client = openai.OpenAI(base_url=base_url, api_key=api_key)
         if 'version' in om_model_config: