diff --git a/badcrossbar/plot.py b/badcrossbar/plot.py index 6021402..2f824d2 100644 --- a/badcrossbar/plot.py +++ b/badcrossbar/plot.py @@ -67,20 +67,7 @@ def branches(device_vals=None, word_line_vals=None, width : float, optional Width of the diagram in millimeters. """ - kwargs.setdefault('default_color', (0, 0, 0)) - kwargs.setdefault('wire_scaling_factor', 1) - kwargs.setdefault('device_scaling_factor', 1) - kwargs.setdefault('node_scaling_factor', 1) - kwargs.setdefault('axis_label', 'Current (A)') - kwargs.setdefault('low_rgb', (213/255, 94/255, 0/255)) - kwargs.setdefault('zero_rgb', (235/255, 235/255, 235/255)) - kwargs.setdefault('high_rgb', (0/255, 114/255, 178/255)) - kwargs.setdefault('allow_overwrite', False) - kwargs.setdefault('filename', 'crossbar-currents') - kwargs.setdefault('device_type', 'memristor') - kwargs.setdefault('significant_figures', 2) - kwargs.setdefault('round_crossings', True) - kwargs.setdefault('width', 210) + kwargs = plotting.utils.set_defaults(kwargs, True) if currents is not None: device_vals = currents.device @@ -99,11 +86,10 @@ def branches(device_vals=None, word_line_vals=None, surface_dims, diagram_pos, segment_length, color_bar_pos, color_bar_dims = \ plotting.crossbar.dimensions(crossbar_shape, width_mm=kwargs.get('width')) - if kwargs.get('allow_overwrite'): - filename = '{}.pdf'.format(kwargs.get('filename')) - filename = sanitize_filepath(filename) - else: - filename = utils.unique_path(kwargs.get('filename'), 'pdf') + + filename = plotting.utils.get_filepath(kwargs.get('filename'), + kwargs.get('allow_overwrite')) + surface = cairo.PDFSurface(filename, *surface_dims) context = cairo.Context(surface) @@ -191,20 +177,7 @@ def nodes(word_line_vals=None, bit_line_vals=None, voltages=None, **kwargs): width : float, optional Width of the diagram in millimeters. """ - kwargs.setdefault('default_color', (0, 0, 0)) - kwargs.setdefault('wire_scaling_factor', 1) - kwargs.setdefault('device_scaling_factor', 1) - kwargs.setdefault('node_scaling_factor', 1.4) - kwargs.setdefault('axis_label', 'Voltage (V)') - kwargs.setdefault('low_rgb', (213/255, 94/255, 0/255)) - kwargs.setdefault('zero_rgb', (235/255, 235/255, 235/255)) - kwargs.setdefault('high_rgb', (0/255, 114/255, 178/255)) - kwargs.setdefault('allow_overwrite', False) - kwargs.setdefault('filename', 'crossbar-voltages') - kwargs.setdefault('device_type', 'memristor') - kwargs.setdefault('significant_figures', 2) - kwargs.setdefault('round_crossings', True) - kwargs.setdefault('width', 210) + kwargs = plotting.utils.set_defaults(kwargs, False) if voltages is not None: word_line_vals = voltages.word_line @@ -219,11 +192,10 @@ def nodes(word_line_vals=None, bit_line_vals=None, voltages=None, **kwargs): surface_dims, diagram_pos, segment_length, color_bar_pos, color_bar_dims = \ plotting.crossbar.dimensions(crossbar_shape, width_mm=kwargs.get('width')) - if kwargs.get('allow_overwrite'): - filename = '{}.pdf'.format(kwargs.get('filename')) - filename = sanitize_filepath(filename) - else: - filename = utils.unique_path(kwargs.get('filename'), 'pdf') + + filename = plotting.utils.get_filepath(kwargs.get('filename'), + kwargs.get('allow_overwrite')) + surface = cairo.PDFSurface(filename, *surface_dims) context = cairo.Context(surface) diff --git a/badcrossbar/plotting/utils.py b/badcrossbar/plotting/utils.py index 35e8717..87de0b3 100644 --- a/badcrossbar/plotting/utils.py +++ b/badcrossbar/plotting/utils.py @@ -1,6 +1,7 @@ import numpy as np import numpy.lib.recfunctions as nlr from sigfig import round +from badcrossbar import utils def complete_path(ctx, rgb=(0, 0, 0), width=1): @@ -163,3 +164,66 @@ def arrays_range(*arrays, sf=2): high = maximum_absolute return low, high + +def set_defaults(kwargs, branches=True): + """Sets default values for kwargs arguments in `badcrossbar.plot` functions. + + Parameters + ---------- + kwargs : dict of any + Optional keyword arguments. + branches : bool + Whether branches are being plotted. If `False`, it is assumed that + nodes are being plotted. + + Returns + ---------- + dict of any + Optional keyword arguments with the default values set. + """ + kwargs.setdefault('default_color', (0, 0, 0)) + kwargs.setdefault('wire_scaling_factor', 1) + kwargs.setdefault('device_scaling_factor', 1) + kwargs.setdefault('axis_label', 'Current (A)') + kwargs.setdefault('low_rgb', (213/255, 94/255, 0/255)) + kwargs.setdefault('zero_rgb', (235/255, 235/255, 235/255)) + kwargs.setdefault('high_rgb', (0/255, 114/255, 178/255)) + kwargs.setdefault('allow_overwrite', False) + kwargs.setdefault('device_type', 'memristor') + kwargs.setdefault('significant_figures', 2) + kwargs.setdefault('round_crossings', True) + kwargs.setdefault('width', 210) + if branches: + kwargs.setdefault('node_scaling_factor', 1) + kwargs.setdefault('filename', 'crossbar-currents') + else: + kwargs.setdefault('node_scaling_factor', 1.4) + kwargs.setdefault('filename', 'crossbar-voltages') + + return kwargs + +def get_filepath(filename, allow_overwrite): + """Constructs filepath of the diagram. + + Parameters + ---------- + filename : str + Filename (without the extension). + allow_overwrite : + If True, can overwrite existing PDF files with the same name. + + Returns + ---------- + str + Filepath of the diagram. + """ + extension = 'pdf' + + if allow_overwrite: + filepath = '{}.{}'.format(filename, extension) + filepath = sanitize_filepath(filepath) + else: + filepath = utils.unique_path(filename, extension) + + return filepath + diff --git a/setup.py b/setup.py index 4e42657..8603ae1 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ def load_requirements(): setup( name='badcrossbar', - version='1.0.0', + version='1.0.1', packages=['badcrossbar', 'badcrossbar.computing', 'badcrossbar.plotting', 'tests'], install_requires=load_requirements(), diff --git a/tests/test_fill.py b/tests/test_fill.py new file mode 100644 index 0000000..c7a1028 --- /dev/null +++ b/tests/test_fill.py @@ -0,0 +1,66 @@ +import badcrossbar.computing as computing +import numpy as np +from scipy.sparse import lil_matrix +from collections import namedtuple +import copy +import pytest + +Interconnect = namedtuple('Interconnect', ['word_line', 'bit_line']) + +applied_voltages_list = [ + np.array([ + [5]]), + np.array([ + [5, 10, -4]]), + np.array([ + [5], + [7]]), + np.array([ + [7, 11, 13], + [-2, 0, 5]])] + +resistances_list = [ + np.ones((1,1)), + np.ones((1,1)), + np.array([[10, 20], [30, 40]]), + np.ones((2,2))] + +r_i_list = [ + Interconnect(0.5, 0), + Interconnect(0.5, 0.25), + Interconnect(0, 0.25), + Interconnect(0.5, 0.25)] + +# i +i_expected = [ + np.array([ + [10]]), + np.array([ + [10, 20, -8], + [0, 0, 0]]), + np.array([ + [5/10], + [5/20], + [7/30], + [7/40]]), + np.array([ + [14, 22, 26], + [0, 0, 0], + [-4, 0, 10], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0]])] + +i_inputs = [i for i in zip( + applied_voltages_list, resistances_list, r_i_list, i_expected)] + + +@pytest.mark.parametrize('applied_voltages,resistances,r_i,expected', i_inputs) +def test_i(applied_voltages, resistances, r_i, expected): + """Tests `badcrossbar.computing.fill.i()`. + """ + i_matrix = computing.fill.i(applied_voltages, resistances, r_i) + np.testing.assert_array_almost_equal(i_matrix, expected) + diff --git a/tests/test_kcl.py b/tests/test_kcl.py new file mode 100644 index 0000000..5a1134e --- /dev/null +++ b/tests/test_kcl.py @@ -0,0 +1,95 @@ +import badcrossbar.computing as computing +import numpy as np +from scipy.sparse import lil_matrix +from collections import namedtuple +import copy +import pytest + +Interconnect = namedtuple('Interconnect', ['word_line', 'bit_line']) +r_i = Interconnect(0.5, 0.25) + +conductances_list = [ + np.array([[100], [0]]), + np.array([[20, 50]]), + np.array([[42]]), + np.array([[0, 20], [30, 0]])] + +g_matrices = [ + lil_matrix((4, 4)), + lil_matrix((4, 4)), + lil_matrix((2, 2)), + lil_matrix((8, 8))] + +# word_line_nodes +word_line_nodes_expected = [ + np.array([ + [102, 0, -100, 0], + [0, 2, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]]), + np.array([ + [24, -2, -20, 0], + [-2, 52, 0, -50], + [0, 0, 0, 0], + [0, 0, 0, 0]]), + np.array([ + [44, -42], + [0, 0]]), + np.array([ + [4, -2, 0, 0, 0, 0, 0, 0], + [-2, 22, 0, 0, 0, -20, 0, 0], + [0, 0, 34, -2, 0, 0, -30, 0], + [0, 0, -2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0]])] + +word_line_nodes_inputs = [i for i in zip( + conductances_list, g_matrices, word_line_nodes_expected)] + +# bit_line_nodes +bit_line_nodes_expected = [ + np.array([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [-100, 0, 104, -4], + [0, 0, -4, 8]]), + np.array([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [-20, 0, 24, 0], + [0, -50, 0, 54]]), + np.array([ + [0, 0], + [-42, 46]]), + np.array([ + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 4, 0, -4, 0], + [0, -20, 0, 0, 0, 24, 0, -4], + [0, 0, -30, 0, -4, 0, 38, 0], + [0, 0, 0, 0, 0, -4, 0, 8]])] + +bit_line_nodes_inputs = [i for i in zip( + conductances_list, g_matrices, bit_line_nodes_expected)] + +@pytest.mark.parametrize('conductances,g_matrix,expected', word_line_nodes_inputs) +def test_word_line_nodes(conductances, g_matrix, expected): + """Tests `badcrossbar.computing.kcl.word_line_nodes()`. + """ + filled_g_matrix = computing.kcl.word_line_nodes( + copy.deepcopy(g_matrix), conductances, r_i).toarray() + np.testing.assert_array_almost_equal(filled_g_matrix, expected) + + +@pytest.mark.parametrize('conductances,g_matrix,expected', bit_line_nodes_inputs) +def test_bit_line_nodes(conductances, g_matrix, expected): + """Tests `badcrossbar.computing.kcl.bit_line_nodes()`. + """ + filled_g_matrix = computing.kcl.bit_line_nodes( + copy.deepcopy(g_matrix), conductances, r_i).toarray() + np.testing.assert_array_almost_equal(filled_g_matrix, expected) +