forked from GoogleCloudPlatform/cloudml-samples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
37 lines (26 loc) · 1.13 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from googleapiclient import discovery
from oauth2client.client import GoogleCredentials
# GCP project that hosts the deployed AI Platform model; fill in before use.
GCP_PROJECT = ''
# Name of the deployed AI Platform model; fill in before use.
CMLE_MODEL_NAME = ''
# Model version to call; None means the model's default version is used.
CMLE_MODEL_VERSION = None
# Sentinel meaning "no version argument given; use CMLE_MODEL_VERSION", so an
# explicit version=None can still request the model's default version.
_UNSET = object()


def predict_cmle(instances, project=None, model=None, version=_UNSET):
    """Use a model deployed to AI Platform to perform online prediction.

    Args:
        instances: list of json, csv, or tf.example objects, based on the
            serving function called.
        project: GCP project hosting the model. Defaults to GCP_PROJECT.
        model: name of the deployed model. Defaults to CMLE_MODEL_NAME.
        version: model version to call. Defaults to CMLE_MODEL_VERSION.
            None or '' selects the model's default version.

    Returns:
        The response dictionary returned by the ``projects.predict`` call.
        If no error occurred, it includes an item with the 'predictions' key.

    Raises:
        ValueError: if the project or model name is empty. Without this check
            the request would target a malformed resource name such as
            'projects//models/'.
    """
    project = GCP_PROJECT if project is None else project
    model = CMLE_MODEL_NAME if model is None else model
    if version is _UNSET:
        version = CMLE_MODEL_VERSION

    # Fail fast on unconfigured settings instead of sending a request with a
    # malformed resource name.
    if not project or not model:
        raise ValueError(
            'GCP project and model name must be set '
            '(got project={!r}, model={!r})'.format(project, model))

    model_url = 'projects/{}/models/{}'.format(project, model)
    # Truthiness (not "is not None") so that '' also means "default version";
    # otherwise the name would end in a dangling '/versions/'.
    if version:
        model_url += '/versions/{}'.format(version)

    # NOTE(review): oauth2client is deprecated upstream; consider migrating to
    # google-auth application-default credentials.
    credentials = GoogleCredentials.get_application_default()
    service = discovery.build('ml', 'v1', credentials=credentials)
    return service.projects().predict(
        body={'instances': instances},
        name=model_url,
    ).execute()