diff --git a/cs_storage/__init__.py b/cs_storage/__init__.py index 6684515..9bd73a8 100644 --- a/cs_storage/__init__.py +++ b/cs_storage/__init__.py @@ -1,3 +1,4 @@ +import base64 import io import json import os @@ -17,13 +18,16 @@ class Serializer: + """ + Base class for serializng input data to bytes and back. + """ def __init__(self, ext): self.ext = ext def serialize(self, data): return data - def deserialize(self, data): + def deserialize(self, data, json_serializable=True): return data @@ -31,7 +35,7 @@ class JSONSerializer(Serializer): def serialize(self, data): return json.dumps(data).encode() - def deserialize(self, data): + def deserialize(self, data, json_serializable=True): return json.loads(data.decode()) @@ -39,21 +43,32 @@ class TextSerializer(Serializer): def serialize(self, data): return data.encode() - def deserialize(self, data): + def deserialize(self, data, json_serializable=True): return data.decode() +class Base64Serializer(Serializer): + def deserialize(self, data, json_serializable=True): + if json_serializable: + return base64.b64encode(data).decode("utf-8") + else: + return data + + def from_string(self, data): + return base64.b64decode(data.encode("utf-8")) + + def get_serializer(media_type): return { "bokeh": JSONSerializer("json"), "table": TextSerializer("html"), "CSV": TextSerializer("csv"), - "PNG": Serializer("png"), - "JPEG": Serializer("jpeg"), - "MP3": Serializer("mp3"), - "MP4": Serializer("mp4"), - "HDF5": Serializer("h5"), - "PDF": Serializer("pdf"), + "PNG": Base64Serializer("png"), + "JPEG": Base64Serializer("jpeg"), + "MP3": Base64Serializer("mp3"), + "MP4": Base64Serializer("mp4"), + "HDF5": Base64Serializer("h5"), + "PDF": Base64Serializer("pdf"), "Markdown": TextSerializer("md"), "Text": TextSerializer("txt"), }[media_type] @@ -130,7 +145,7 @@ def write(task_id, loc_result, do_upload=True): return rem_result -def read(rem_result): +def read(rem_result, json_serializable=True): # compute studio results have public read access. fs = gcsfs.GCSFileSystem(token="anon") s = time.time() @@ -145,7 +160,7 @@ def read(rem_result): for rem_output in rem_result[category]["outputs"]: ser = get_serializer(rem_output["media_type"]) - rem_data = ser.deserialize(zipfileobj.read(rem_output["filename"])) + rem_data = ser.deserialize(zipfileobj.read(rem_output["filename"]), json_serializable) read[category].append( { "title": rem_output["title"], diff --git a/cs_storage/tests/test_cs_storage.py b/cs_storage/tests/test_cs_storage.py index 51500f3..cc95ade 100644 --- a/cs_storage/tests/test_cs_storage.py +++ b/cs_storage/tests/test_cs_storage.py @@ -1,3 +1,4 @@ +import base64 import io import json import uuid @@ -10,6 +11,44 @@ import cs_storage +@pytest.fixture +def png(): + import matplotlib.pyplot as plt + import numpy as np + x = np.linspace(0, 2, 100) + plt.figure() + plt.plot(x, x, label='linear') + plt.plot(x, x**2, label='quadratic') + plt.plot(x, x**3, label='cubic') + plt.xlabel('x label') + plt.ylabel('y label') + plt.title("Simple Plot") + plt.legend() + initial_buff = io.BytesIO() + plt.savefig(initial_buff, format="png") + initial_buff.seek(0) + return initial_buff.read() + + +@pytest.fixture +def jpg(): + import matplotlib.pyplot as plt + import numpy as np + x = np.linspace(0, 2, 100) + plt.figure() + plt.plot(x, x, label='linear') + plt.plot(x, x**2, label='quadratic') + plt.plot(x, x**3, label='cubic') + plt.xlabel('x label') + plt.ylabel('y label') + plt.title("Simple Plot") + plt.legend() + initial_buff = io.BytesIO() + plt.savefig(initial_buff, format="jpg") + initial_buff.seek(0) + return initial_buff.read() + + def test_JSONSerializer(): ser = cs_storage.JSONSerializer("json") @@ -46,13 +85,28 @@ def test_serializer(): assert act == b"hello world" +def test_base64serializer(png, jpg): + """Test round trip serializtion/deserialization of PNG and JPG""" + ser = cs_storage.Base64Serializer("PNG") + asbytes = ser.serialize(png) + asstr = ser.deserialize(asbytes) + assert png == ser.from_string(asstr) + assert json.dumps({"pic": asstr}) + + ser = cs_storage.Base64Serializer("JPG") + asbytes = ser.serialize(jpg) + asstr = ser.deserialize(asbytes) + assert jpg == ser.from_string(asstr) + assert json.dumps({"pic": asstr}) + + def test_get_serializer(): types = ["bokeh", "table", "CSV", "PNG", "JPEG", "MP3", "MP4", "HDF5"] for t in types: assert cs_storage.get_serializer(t) -def test_cs_storage(): +def test_cs_storage(png, jpg): exp_loc_res = { "renderable": [ { @@ -68,12 +122,12 @@ def test_cs_storage(): { "media_type": "PNG", "title": "PNG data", - "data": b"PNG bytes", + "data": png, }, { "media_type": "JPEG", "title": "JPEG data", - "data": b"JPEG bytes", + "data": jpg, }, { "media_type": "MP3", @@ -117,11 +171,17 @@ def test_cs_storage(): } task_id = uuid.uuid4() rem_res = cs_storage.write(task_id, exp_loc_res) - loc_res = cs_storage.read(rem_res) + loc_res = cs_storage.read(rem_res, json_serializable=False) assert loc_res == exp_loc_res + assert json.dumps( + cs_storage.read(rem_res, json_serializable=True) + ) - loc_res1 = cs_storage.read({"renderable": rem_res["renderable"]}) + loc_res1 = cs_storage.read({"renderable": rem_res["renderable"]}, json_serializable=False) assert loc_res1["renderable"] == exp_loc_res["renderable"] + assert json.dumps( + cs_storage.read({"renderable": rem_res["renderable"]}, json_serializable=True) + ) def test_errors(): diff --git a/environment.yml b/environment.yml index 4abf0a1..0d95114 100644 --- a/environment.yml +++ b/environment.yml @@ -5,3 +5,5 @@ dependencies: - "marshmallow>=3.0.0" - pytest - gcsfs + - matplotlib + - numpy