Skip to content

Commit

Permalink
Merge pull request #53 from intezer/fix/improve-download-file
Browse files Browse the repository at this point in the history
fix(download file): allow passing file-like object to download file, …
  • Loading branch information
davidt99 authored Jul 13, 2022
2 parents 3c64cb6 + 1304262 commit 4707a75
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 19 deletions.
5 changes: 5 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
1.9.2
-------
- Allow passing a file-like object to download file
- When a directory is passed to download file, the file name is taken from the response

1.9.1
-------
- Optional latest family search on get analysis metadata
Expand Down
2 changes: 1 addition & 1 deletion intezer_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.9.1'
__version__ = '1.9.2'
11 changes: 9 additions & 2 deletions intezer_sdk/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from http import HTTPStatus
from typing import BinaryIO
from typing import IO
from typing import Optional

import requests
Expand Down Expand Up @@ -143,8 +144,14 @@ def _init_sub_analyses(self):
else:
self._sub_analyses.append(sub_analysis_object)

def download_file(self, path: Optional[str] = None, output_stream: Optional[IO] = None) -> None:
    """
    Downloads the analysis's file.

    Exactly one of `path` or `output_stream` must be provided; the arguments are forwarded as-is to
    `download_file_by_sha256`, which raises `ValueError` when neither or both are given.

    :param path: A path to where to save the file, it can be either a directory or non-existing file path.
    :param output_stream: A file-like object to write the file's content to.
    """
    # The file is identified by the sha256 reported in the analysis result.
    self._api.download_file_by_sha256(self.result()['sha256'], path, output_stream)

@property
def iocs(self) -> dict:
Expand Down
51 changes: 37 additions & 14 deletions intezer_sdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any
from typing import BinaryIO
from typing import Dict
from typing import IO
from typing import List
from typing import Optional
from typing import Union
Expand Down Expand Up @@ -78,7 +79,8 @@ def _request(self,
path: str,
data: dict = None,
headers: dict = None,
files: dict = None) -> Response:
files: dict = None,
stream: bool = None) -> Response:
if not self._session:
self.set_session()

Expand All @@ -88,14 +90,16 @@ def _request(self,
self.full_url + path,
files=files,
data=data or {},
headers=headers or {}
headers=headers or {},
stream=stream
)
else:
response = self._session.request(
method,
self.full_url + path,
json=data or {},
headers=headers
headers=headers,
stream=stream
)

return response
Expand All @@ -105,13 +109,14 @@ def request_with_refresh_expired_access_token(self,
path: str,
data: dict = None,
headers: dict = None,
files: dict = None) -> Response:
files: dict = None,
stream: bool = None) -> Response:
response = self._request(method, path, data, headers, files)

if response.status_code == HTTPStatus.UNAUTHORIZED:
self._access_token = None
self.set_session()
response = self._request(method, path, data, headers, files)
response = self._request(method, path, data, headers, files, stream)

return response

Expand Down Expand Up @@ -335,20 +340,38 @@ def get_url_result(self, url: str) -> Optional[Response]:

return response

def download_file_by_sha256(self, sha256: str, path: Optional[str] = None, output_stream: Optional[IO] = None) -> None:
    """
    Downloads the file with the given sha256, saving it to `path` or writing it to `output_stream`.

    Exactly one of `path` or `output_stream` must be provided.

    :param sha256: The sha256 of the file to download.
    :param path: Where to save the file. Either a directory (the file name is then taken from the
                 response's content-disposition header, falling back to `<sha256>.sample`) or a
                 non-existing file path.
    :param output_stream: A file-like object to write the file's content to.
    :raises ValueError: If neither or both of `path` and `output_stream` are provided.
    :raises FileExistsError: If the destination file already exists.
    """
    # Compare the stream against None: a valid file-like object is not guaranteed to be truthy.
    if not path and output_stream is None:
        raise ValueError('You must provide either path or output_stream')
    if path and output_stream is not None:
        raise ValueError('You must provide either path or output_stream, not both')

    save_into_directory = False
    if path:
        if os.path.isdir(path):
            # The final file name is only known once the response headers arrive.
            save_into_directory = True
        elif os.path.isfile(path):
            raise FileExistsError(path)

    # Stream only when saving to disk, so the file is written chunk by chunk.
    # NOTE(review): this relies on request_with_refresh_expired_access_token forwarding `stream`
    # on its first attempt as well as on the retry - verify.
    response = self.request_with_refresh_expired_access_token(path='/files/{}/download'.format(sha256),
                                                              method='GET',
                                                              stream=bool(path))

    raise_for_status(response)

    if output_stream is not None:
        output_stream.write(response.content)
        return

    if save_into_directory:
        content_disposition = response.headers.get('content-disposition') or ''
        raw_name = content_disposition.partition('filename=')[2].split(';', 1)[0]
        # The name is server-controlled: drop surrounding quotes and any directory components
        # so the file cannot be written outside of the requested directory.
        filename = os.path.basename(raw_name.strip().strip('"\'').replace('\\', '/'))
        if filename in ('', '.', '..'):
            filename = f'{sha256}.sample'

        path = os.path.join(path, filename)
        # Never silently overwrite an existing file, same as for an explicit file path.
        if os.path.isfile(path):
            raise FileExistsError(path)

    with open(path, 'wb') as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)

def index_by_sha256(self, sha256: str, index_as: IndexType, family_name: str = None) -> Response:
data = {'index_as': index_as.value}
Expand Down
11 changes: 9 additions & 2 deletions intezer_sdk/sub_analysis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
from typing import IO
from typing import Optional
from typing import Union

Expand Down Expand Up @@ -153,5 +154,11 @@ def _handle_operation(self,

return self._operations[operation]

def download_file(self, path: Optional[str] = None, output_stream: Optional[IO] = None) -> None:
    """
    Downloads the file of this sub-analysis.

    Exactly one of `path` or `output_stream` must be provided; the arguments are forwarded as-is to
    `download_file_by_sha256`, which raises `ValueError` when neither or both are given.

    :param path: A path to where to save the file, it can be either a directory or non-existing file path.
    :param output_stream: A file-like object to write the file's content to.
    """
    self._api.download_file_by_sha256(self.sha256, path, output_stream)
112 changes: 112 additions & 0 deletions tests/unit/test_analysis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import datetime
import io
import json
import os
import tempfile
import uuid
from http import HTTPStatus
from unittest.mock import mock_open
Expand Down Expand Up @@ -932,6 +935,115 @@ def test_get_analysis_by_id_raises_when_analysis_is_queued(self):
with self.assertRaises(errors.AnalysisIsStillRunningError):
FileAnalysis.from_analysis_id(analysis_id)

def test_download_file_path_uses_content_disposition(self):
    """Downloading into a directory names the file after the content-disposition header."""
    # Arrange
    sha256 = 'hash'
    expected_name = 'a.sample'
    latest_analysis = {'result': {'analysis_id': 'analysis_id', 'sha256': sha256}}
    download_headers = {'content-disposition': f'inline; filename={expected_name}'}

    with responses.RequestsMock() as mock:
        mock.add('GET', url=f'{self.full_url}/files/{sha256}', status=200, json=latest_analysis)
        mock.add('GET',
                 url=f'{self.full_url}/files/{sha256}/download',
                 status=200,
                 body=b'asd',
                 headers=download_headers)

        analysis = FileAnalysis.from_latest_hash_analysis(sha256)
        with tempfile.TemporaryDirectory() as target_dir:
            # Act
            analysis.download_file(target_dir)

            # Assert
            created = os.listdir(target_dir)
            self.assertEqual(expected_name, created[0])

def test_download_file_path_uses_default_file_name(self):
    """Without a content-disposition header the file is saved as `<sha256>.sample`."""
    # Arrange
    sha256 = 'hash'
    expected_name = f'{sha256}.sample'
    latest_analysis = {'result': {'analysis_id': 'analysis_id', 'sha256': sha256}}

    with responses.RequestsMock() as mock:
        mock.add('GET', url=f'{self.full_url}/files/{sha256}', status=200, json=latest_analysis)
        mock.add('GET', url=f'{self.full_url}/files/{sha256}/download', status=200, body=b'asd')

        analysis = FileAnalysis.from_latest_hash_analysis(sha256)
        with tempfile.TemporaryDirectory() as target_dir:
            # Act
            analysis.download_file(target_dir)

            # Assert
            created = os.listdir(target_dir)
            self.assertEqual(expected_name, created[0])

def test_download_file_path(self):
    """Downloading to an explicit, non-existing file path writes the response body there."""
    # Arrange
    sha256 = 'hash'
    expected_content = b'asd'
    latest_analysis = {'result': {'analysis_id': 'analysis_id', 'sha256': sha256}}

    with responses.RequestsMock() as mock:
        mock.add('GET', url=f'{self.full_url}/files/{sha256}', status=200, json=latest_analysis)
        mock.add('GET', url=f'{self.full_url}/files/{sha256}/download', status=200, body=expected_content)

        analysis = FileAnalysis.from_latest_hash_analysis(sha256)
        with tempfile.TemporaryDirectory() as target_dir:
            destination = os.path.join(target_dir, f'{sha256}.sample')

            # Act
            analysis.download_file(destination)

            # Assert
            with open(destination, 'rb') as saved:
                self.assertEqual(expected_content, saved.read())

def test_download_file_output_stream(self):
    """Downloading into a file-like object writes the response body to it."""
    # Arrange
    sha256 = 'hash'
    expected_content = b'asd'
    latest_analysis = {'result': {'analysis_id': 'analysis_id', 'sha256': sha256}}

    with responses.RequestsMock() as mock:
        mock.add('GET', url=f'{self.full_url}/files/{sha256}', status=200, json=latest_analysis)
        mock.add('GET', url=f'{self.full_url}/files/{sha256}/download', status=200, body=expected_content)

        analysis = FileAnalysis.from_latest_hash_analysis(sha256)
        buffer = io.BytesIO()

        # Act
        analysis.download_file(output_stream=buffer)

        # Assert
        buffer.seek(0)  # rewind so the read below starts at the beginning
        self.assertEqual(expected_content, buffer.read())

def test_download_file_raises_when_providing_output_stream_and_path(self):
    """Passing both a path and an output stream is rejected with ValueError."""
    # Arrange
    sha256 = 'hash'
    latest_analysis = {'result': {'analysis_id': 'analysis_id', 'sha256': sha256}}

    with responses.RequestsMock() as mock:
        mock.add('GET', url=f'{self.full_url}/files/{sha256}', status=200, json=latest_analysis)

        analysis = FileAnalysis.from_latest_hash_analysis(sha256)
        buffer = io.BytesIO()

        # Act and Assert
        with self.assertRaises(ValueError):
            analysis.download_file(path='asd', output_stream=buffer)

def test_download_file_raises_when_not_providing_output_stream_and_path(self):
    """Calling download_file with neither a path nor an output stream raises ValueError."""
    # Arrange
    sha256 = 'hash'
    latest_analysis = {'result': {'analysis_id': 'analysis_id', 'sha256': sha256}}

    with responses.RequestsMock() as mock:
        mock.add('GET', url=f'{self.full_url}/files/{sha256}', status=200, json=latest_analysis)

        analysis = FileAnalysis.from_latest_hash_analysis(sha256)

        # Act and Assert
        with self.assertRaises(ValueError):
            analysis.download_file()


class UrlAnalysisSpec(BaseTest):
def test_get_analysis_by_id_analysis_object_when_latest_analysis_found(self):
Expand Down

0 comments on commit 4707a75

Please sign in to comment.