Skip to content

Commit

Permalink
add mocked unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaohanzhan-db committed Dec 2, 2023
1 parent aaed9be commit df4a6d4
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions tests/test_convert_delta_to_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# copyright 2022 mosaicml llm foundry authors
# spdx-license-identifier: apache-2.0

import pytest
import os
import sys
import warnings

# Add repo root to path so we can import scripts and test it
repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(repo_dir)

import unittest
from unittest.mock import patch, MagicMock, mock_open
from scripts.data_prep.convert_delta_to_json import stream_delta_to_json

class TestStreamDeltaToJson(unittest.TestCase):

@patch('scripts.data_prep.convert_delta_to_json.sql.connect')
@patch('scripts.data_prep.convert_delta_to_json.pd.DataFrame.to_json')
def test_stream_delta_to_json(self, mock_to_json, mock_connect):
mock_args = MagicMock()
mock_args.DATABRICKS_HOST = 'test_host'
mock_args.DATABRICKS_TOKEN = 'test_token'
mock_args.http_path = 'test_http_path'
mock_args.delta_table_name = 'test_table'
mock_args.json_output_path = 'test_output_path'

# Mock database connection and cursor
mock_cursor = MagicMock()
mock_connection = MagicMock()
mock_connection.cursor.return_value = mock_cursor
mock_connect.return_value = mock_connection

# Mock fetchall response
count_response = MagicMock()
count_response.asDict.return_value = {'COUNT(*)': 3}
column_response_item = MagicMock()
column_response_item.asDict.return_value = {'COLUMN_NAME': 'name'} # Assuming SHOW COLUMNS query returns this format
data_response_item = MagicMock()
data_response_item.asDict.return_value = {'name': 'test', 'id': 1} # Assuming SELECT query returns this format
mock_cursor.fetchall.side_effect = [[count_response], [column_response_item], [data_response_item]]

stream_delta_to_json(mock_args)

mock_connect.assert_called_once_with(server_hostname='test_host', http_path='test_http_path', access_token='test_token')
mock_to_json.assert_called()
mock_cursor.close.assert_called()
mock_connection.close.assert_called()

if __name__ == '__main__':
unittest.main()

0 comments on commit df4a6d4

Please sign in to comment.