diff --git a/target_csv.py b/target_csv.py index c372545..aff869f 100755 --- a/target_csv.py +++ b/target_csv.py @@ -16,6 +16,9 @@ from jsonschema.validators import Draft4Validator import singer +DEFAULT_PARENT_KEY = '' +SEP = '__' + logger = singer.get_logger() def emit_state(state): @@ -25,21 +28,33 @@ def emit_state(state): sys.stdout.write("{}\n".format(line)) sys.stdout.flush() -def flatten(d, parent_key='', sep='__'): +def flatten(d, parent_key=DEFAULT_PARENT_KEY, sep=SEP): items = [] for k, v in d.items(): - new_key = parent_key + sep + k if parent_key else k + new_key = generate_key(parent_key, sep, k) if isinstance(v, collections.MutableMapping): items.extend(flatten(v, new_key, sep=sep).items()) else: items.append((new_key, str(v) if type(v) is list else v)) return dict(items) - + +def get_headers(schema, parent_key=DEFAULT_PARENT_KEY, sep=SEP): + headers = [] + for k, v in schema.get('properties', {}).items(): + new_key = generate_key(parent_key, sep, k) + if v.get('type') == 'object': + headers = headers + get_headers(v, new_key, sep=sep) + else: + headers.append(new_key) + return list(headers) + +def generate_key(parent_key, sep, key): + return parent_key + sep + key if parent_key else key + def persist_messages(delimiter, quotechar, messages, destination_path): state = None schemas = {} key_properties = {} - headers = {} validators = {} now = datetime.now().strftime('%Y%m%dT%H%M%S') @@ -64,19 +79,10 @@ def persist_messages(delimiter, quotechar, messages, destination_path): flattened_record = flatten(o['record']) - if o['stream'] not in headers and not file_is_empty: - with open(filename, 'r') as csvfile: - reader = csv.reader(csvfile, - delimiter=delimiter, - quotechar=quotechar) - first_line = next(reader) - headers[o['stream']] = first_line if first_line else flattened_record.keys() - else: - headers[o['stream']] = flattened_record.keys() - + headers = get_headers(schemas[o['stream']]) with open(filename, 'a') as csvfile: writer = csv.DictWriter(csvfile, - headers[o['stream']], + headers, extrasaction='ignore', delimiter=delimiter, quotechar=quotechar) diff --git a/target_csv_test.py b/target_csv_test.py new file mode 100644 index 0000000..078f444 --- /dev/null +++ b/target_csv_test.py @@ -0,0 +1,37 @@ +import target_csv +import unittest + + +class TestTargetCsv(unittest.TestCase): + + def setUp(self): + pass + + def test_get_headers(self): + schema = {"properties": { + 'a': {'type': 'string'}, + 'b': {'type': 'array', 'items': {'type': 'string'}}, + 'c': {'type': 'object', 'properties': {'d': {'type': 'string'}}} + }} + + assert target_csv.get_headers(schema) == ['a', 'b', 'c__d'] + + def test_get_headers_matches_flatten(self): + schema = {'properties': { + 'a': {'type': 'string'}, + 'b': {'type': 'array', 'items': {'type': 'string'}}, + 'c': {'type': 'object', 'properties': {'d': {'type': 'string'}}} + }} + + record = { + 'a': 'alpha', + 'b': ['beta'], + 'c': {'d': 'delta'} + } + + keys_from_schema = target_csv.get_headers(schema) + keys_from_records = list(target_csv.flatten(record).keys()) + assert keys_from_schema == keys_from_records + +if __name__ == '__main__': + unittest.main()