Skip to content

Commit

Permalink
Merge pull request Azure#9568 from V1ManagedServices/master
Browse files Browse the repository at this point in the history
token 403 enhancement
  • Loading branch information
v-atulyadav authored Jan 3, 2024
2 parents 775836a + 08f7ee9 commit 730c2b3
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 15 deletions.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ def main(oatFileMsg: func.QueueMessage) -> None:
f'Account token not found, clp: {message.clp_id}, stop current job.'
)
return

if utils.check_token_is_expired(account_token):
logger.error(f"token is expired, clp: {message.clp_id}")
return

oat_file = download_oat_file(
account_token, message.package_id, message.pipeline_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def main(
if not token:
logger.warning(f'Account token not found, clp: {clp_id}, stop current job.')
return

if utils.check_token_is_expired(token):
logger.error(f"token is expired, clp: {clp_id}")
return

pipeline_id = oat_service.get_oat_pipeline_id(clp_id)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,15 @@ def main(rcaMsg: func.QueueMessage) -> None:
if not token:
raise GeneralException(f'Token not found for clp: {clp_id}')

if utils.check_token_is_expired(token):
logging.error(f"token is expired, clp: {clp_id}")
return

rca_task_detail = get_rca_task_detail(token, task_id, target_guid)

if not rca_task_detail:
logging.error(f"No rca_task_detail for clp: {clp_id}")
return

target_info = {
'xdrCustomerID': clp_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,17 @@ def customize_json(clp_id, workbench_detail, workbench_record):
return xdr_log


def build_queue_message(clp_id, workbench_id, task_id, task_name, target_guid, target_info):
def build_queue_message(
clp_id, workbench_id, task_id, task_name, target_guid, target_info
):
return json.dumps(
{
'clp_id': clp_id,
'workbench_id': workbench_id,
'task_id': task_id,
'task_name': task_name,
'target_guid': target_guid,
'target_info': target_info
'target_info': target_info,
}
)

Expand All @@ -144,9 +146,13 @@ def main(wbMsg: func.QueueMessage, rcaMsg: func.Out[typing.List[str]]) -> None:
if not token:
raise GeneralException(f'Token not found for clp: {clp_id}')

if utils.check_token_is_expired(token):
logging.error(f"token is expired, clp: {clp_id}")
return

# get workbench detail
workbench_detail = get_workbench_detail(token, workbench_id)

if not workbench_detail:
logging.warning(
f'Could not get workbench data. Workbench id: {workbench_id}.'
Expand All @@ -166,9 +172,11 @@ def main(wbMsg: func.QueueMessage, rcaMsg: func.Out[typing.List[str]]) -> None:
rca_tasks = []
rac_task_log = []


# get rca task
rca_raw_tasks = get_rca_task(token, workbench_id,)
rca_raw_tasks = get_rca_task(
token,
workbench_id,
)

for task in rca_raw_tasks:
task_status = task['status']
Expand All @@ -177,9 +185,11 @@ def main(wbMsg: func.QueueMessage, rcaMsg: func.Out[typing.List[str]]) -> None:
f'Get rca task with status: {task_status}, Workbench id: {workbench_id}. No need to get rca detail.'
)
continue

# process prca task info
rac_task_log.append(transform_utils.transform_rca_task(clp_id, workbench_id ,task))
rac_task_log.append(
transform_utils.transform_rca_task(clp_id, workbench_id, task)
)

for target in task['targets']:
target_status = target['targetStatus']
Expand All @@ -194,14 +204,21 @@ def main(wbMsg: func.QueueMessage, rcaMsg: func.Out[typing.List[str]]) -> None:

rca_tasks.append(
build_queue_message(
clp_id, workbench_id, task['id'], task['name'], target['guid'], target_info
clp_id,
workbench_id,
task['id'],
task['name'],
target['guid'],
target_info,
)
)

if len(rac_task_log) > 0:
log_analytics = LogAnalytics(WORKSPACE_ID, WORKSPACE_KEY, RCA_TASK_LOG_TYPE)
log_analytics.post_data(rac_task_log)
logging.info(f'Send prca task data successfully. Workbench id: {workbench_id}, Count: {len(rac_task_log)}.')
logging.info(
f'Send prca task data successfully. Workbench id: {workbench_id}, Count: {len(rac_task_log)}.'
)

if rca_tasks:
rcaMsg.set(rca_tasks)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def get_trace_log(headers):


def update_oat_pipeline_config(token: str, patch_data: Dict[str, Any]) -> None:
# See: https://adc.github.trendmicro.com/pages/CoreTech-SG/xdr-doc/?urls.primaryName=public-merged-beta#/Observed%20Attack%20Techniques%20Pipeline/patch_beta_xdr_oat_dataPipeline
url = f"{XDR_HOST_URL}/beta/xdr/oat/dataPipeline"
headers = get_header(
{
Expand Down Expand Up @@ -118,6 +117,9 @@ def get_oat_package_list(
f'start_time: {start_time}, end_time: {end_time}'
)
return 0, []
if response.status_code in [requests.codes.forbidden, requests.codes.not_found]:
logger.error(f"response status code: {response.status_code}")
return 0, []

response.raise_for_status()

Expand Down Expand Up @@ -165,6 +167,9 @@ def download_oat_file(
f'The OAT file is out of retention time, file_id: {oat_file_id}'
)
return None
if response.status_code in [requests.codes.forbidden, requests.codes.not_found]:
logger.error(f"response status code: {response.status_code}")
return None

response.raise_for_status()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ def get_workbench_list(token, start_time, end_time, offset=0, limit=200):
f'Get workbench list response: {response.text}'
f'Get workbench list trace: {get_trace_log(response.headers)}'
)

if response.status_code in [requests.codes.forbidden, requests.codes.not_found]:
logger.error(f"response status code: {response.status_code}")
return 0, []

response.raise_for_status()
response_data = response.json()

Expand All @@ -89,6 +94,11 @@ def get_workbench_detail(token, workbench_id):
f'Get workbench detail response: {response.text}.'
f'Get workbench detail trace: {get_trace_log(response.headers)}'
)

if response.status_code in [requests.codes.forbidden, requests.codes.not_found]:
logger.error(f"response status code: {response.status_code}")
return []

response.raise_for_status()
response_data = response.json()

Expand Down Expand Up @@ -130,6 +140,11 @@ def get_rca_task(token, workbench_id):
f'Get rca task response: {response.text}'
f'Get rca task trace: {get_trace_log(response.headers)}'
)

if response.status_code in [requests.codes.forbidden, requests.codes.not_found]:
logger.error(f"response status code: {response.status_code}")
return []

response.raise_for_status()
response_data = response.json()

Expand All @@ -153,6 +168,11 @@ def get_rca_task_detail(token, task_id, endpoint_guid):
f'Get rca detail response: {response.text}'
f'Get rca detail trace: {get_trace_log(response.headers)}'
)

if response.status_code in [requests.codes.forbidden, requests.codes.not_found]:
logger.error(f"response status code: {response.status_code}")
return []

response.raise_for_status()
response_data = response.json()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ def find_token_by_clp(clp_id, api_tokens):
return next(filter(lambda token: get_clp_id(token) == clp_id, api_tokens), None)


def check_token_is_expired(token: str) -> bool:
try:
return datetime.now() > datetime.fromtimestamp(
jwt.decode(token, options={"verify_signature": False}).get('et')
)
except Exception as e:
logger.error(f"if_token_is_expired checking Error, e: {e}")
return False


@timer
def get_last_success_time(table_name, clp_id):
try:
Expand Down Expand Up @@ -62,9 +72,6 @@ def update_last_success_time(table_name, clp_id, time):
)





@timer
def send_message_to_storage_queue(
queue_name, message, conn_str=STORAGE_CONNECTION_STRING
Expand All @@ -77,7 +84,7 @@ def send_message_to_storage_queue(

try:
queue_client.create_queue()
except ResourceExistsError as e:
logger.error(e)
except ResourceExistsError:
logger.info(f"Queue {queue_name} already exists.")

queue_client.send_message(message)
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def main(mytimer: func.TimerRequest, wbMsg: func.Out[typing.List[str]]) -> None:
clp_id = utils.get_clp_id(token)
health_check_data['clpId'] = clp_id

if utils.check_token_is_expired(token):
logging.error(f"token is expired, clp: {clp_id}")
continue

start_time, end_time = generate_time(table_service, clp_id)
start_time_str = start_time.strftime(DATETIME_FORMAT)
end_time_str = end_time.strftime(DATETIME_FORMAT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def main(myTimer: func.TimerRequest, oatTaskMsg: func.Out[List[str]]) -> None:
clp_id = utils.get_clp_id(token)
health_check_data['clpId'] = clp_id

if utils.check_token_is_expired(token):
logger.error(f"token is expired, clp: {clp_id}")
continue

start_time, end_time = generate_time(clp_id, token)
if start_time is None or end_time is None:
logger.warning(
Expand Down

0 comments on commit 730c2b3

Please sign in to comment.