Skip to content

Commit

Permalink
Fix handling of PNG/JPG data
Browse files Browse the repository at this point in the history
  • Loading branch information
hdoupe committed Sep 25, 2019
1 parent 3e0a396 commit 979474d
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 6 deletions.
21 changes: 15 additions & 6 deletions cs_storage/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import io
import json
import os
Expand Down Expand Up @@ -43,17 +44,25 @@ def deserialize(self, data):
return data.decode()


class Base64Serializer(Serializer):
def deserialize(self, data):
return base64.b64encode(data).decode("utf-8")

def from_string(self, data):
return base64.b64decode(data.encode("utf-8"))


def get_serializer(media_type):
return {
"bokeh": JSONSerializer("json"),
"table": TextSerializer("html"),
"CSV": TextSerializer("csv"),
"PNG": Serializer("png"),
"JPEG": Serializer("jpeg"),
"MP3": Serializer("mp3"),
"MP4": Serializer("mp4"),
"HDF5": Serializer("h5"),
"PDF": Serializer("pdf"),
"PNG": Base64Serializer("png"),
"JPEG": Base64Serializer("jpeg"),
"MP3": Base64Serializer("mp3"),
"MP4": Base64Serializer("mp4"),
"HDF5": Base64Serializer("h5"),
"PDF": Base64Serializer("pdf"),
"Markdown": TextSerializer("md"),
"Text": TextSerializer("txt"),
}[media_type]
Expand Down
54 changes: 54 additions & 0 deletions cs_storage/tests/test_cs_storage.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import io
import json
import uuid
Expand All @@ -10,6 +11,44 @@
import cs_storage


@pytest.fixture
def png():
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 2, 100)
plt.figure()
plt.plot(x, x, label='linear')
plt.plot(x, x**2, label='quadratic')
plt.plot(x, x**3, label='cubic')
plt.xlabel('x label')
plt.ylabel('y label')
plt.title("Simple Plot")
plt.legend()
initial_buff = io.BytesIO()
plt.savefig(initial_buff, format="png")
initial_buff.seek(0)
return initial_buff.read()


@pytest.fixture
def jpg():
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 2, 100)
plt.figure()
plt.plot(x, x, label='linear')
plt.plot(x, x**2, label='quadratic')
plt.plot(x, x**3, label='cubic')
plt.xlabel('x label')
plt.ylabel('y label')
plt.title("Simple Plot")
plt.legend()
initial_buff = io.BytesIO()
plt.savefig(initial_buff, format="jpg")
initial_buff.seek(0)
return initial_buff.read()


def test_JSONSerializer():
ser = cs_storage.JSONSerializer("json")

Expand Down Expand Up @@ -46,6 +85,21 @@ def test_serializer():
assert act == b"hello world"


def test_base64serializer(png, jpg):
"""Test round trip serializtion/deserialization of PNG and JPG"""
ser = cs_storage.Base64Serializer("PNG")
asbytes = ser.serialize(png)
asstr = ser.deserialize(asbytes)
assert png == ser.from_string(asstr)
assert json.dumps({"pic": asstr})

ser = cs_storage.Base64Serializer("JPG")
asbytes = ser.serialize(jpg)
asstr = ser.deserialize(asbytes)
assert jpg == ser.from_string(asstr)
assert json.dumps({"pic": asstr})


def test_get_serializer():
types = ["bokeh", "table", "CSV", "PNG", "JPEG", "MP3", "MP4", "HDF5"]
for t in types:
Expand Down
2 changes: 2 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ dependencies:
- "marshmallow>=3.0.0"
- pytest
- gcsfs
- matplotlib
- numpy

0 comments on commit 979474d

Please sign in to comment.