Skip to content

Commit

Permalink
Merge branch 'main' into lint/ruff-hail
Browse files Browse the repository at this point in the history
  • Loading branch information
iris-garden committed Feb 6, 2024
2 parents c77c6a2 + 671deef commit 66e3835
Show file tree
Hide file tree
Showing 181 changed files with 7,220 additions and 1,005 deletions.
4 changes: 4 additions & 0 deletions batch/batch/cloud/azure/worker/worker_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import aiohttp
import orjson
from aiohttp import web

from hailtop import httpx
from hailtop.aiocloud import aioazure
Expand Down Expand Up @@ -60,6 +61,9 @@ async def user_container_registry_credentials(self, credentials: Dict[str, str])
credentials = orjson.loads(base64.b64decode(credentials['key.json']).decode())
return {'username': credentials['appId'], 'password': credentials['password']}

def create_metadata_server_app(self, credentials: Dict[str, str]) -> web.Application:
    """Build a per-job metadata-server aiohttp application.

    Not implemented for Azure workers — calling this always raises.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError

def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> AzureSlimInstanceConfig:
    """Deserialize an AzureSlimInstanceConfig from its dict representation."""
    return AzureSlimInstanceConfig.from_dict(config_dict)

Expand Down
12 changes: 8 additions & 4 deletions batch/batch/cloud/gcp/driver/create_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,18 +230,22 @@ def scheduling() -> dict:
- /batch/jvm-container-logs/jvm-*.log
record_log_file_path: true
processors:
parse_message:
type: parse_json
labels:
type: modify_fields
fields:
labels.namespace:
static_value: $NAMESPACE
labels.instance_id:
static_value: $INSTANCE_ID
severity:
move_from: jsonPayload.severity
service:
log_level: error
pipelines:
default_pipeline:
processors: [labels]
processors: [parse_message, labels]
receivers: [runlog, workerlog, jvmlog]
metrics:
Expand All @@ -262,9 +266,9 @@ def scheduling() -> dict:
iptables --table nat --append POSTROUTING --source 172.20.0.0/15 --jump MASQUERADE
# [public]
# Block public traffic to the metadata server
iptables --append FORWARD --source 172.21.0.0/16 --destination 169.254.169.254 --jump DROP
# But allow the internal gateway
# Send public jobs' metadata server requests to the batch worker itself
iptables --table nat --append PREROUTING --source 172.21.0.0/16 --destination 169.254.169.254 -p tcp -j REDIRECT --to-ports 5555
# Allow the internal gateway
iptables --append FORWARD --destination $INTERNAL_GATEWAY_IP --jump ACCEPT
# And this worker
iptables --append FORWARD --destination $IP_ADDRESS --jump ACCEPT
Expand Down
2 changes: 1 addition & 1 deletion batch/batch/cloud/gcp/instance_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def create(
GCPStaticSizedDiskResource.create(product_versions, 'pd-ssd', boot_disk_size_gb, region),
data_disk_resource,
GCPDynamicSizedDiskResource.create(product_versions, 'pd-ssd', region),
GCPIPFeeResource.create(product_versions, 1024),
GCPIPFeeResource.create(product_versions, 1024, preemptible),
GCPServiceFeeResource.create(product_versions),
GCPSupportLogsSpecsAndFirewallFees.create(product_versions),
]
Expand Down
9 changes: 5 additions & 4 deletions batch/batch/cloud/gcp/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,17 +270,18 @@ class GCPIPFeeResource(IPFeeResourceMixin, GCPResource):
TYPE = 'gcp_ip_fee'

@staticmethod
def product_name(base: int) -> str:
return f'ip-fee/{base}'
def product_name(base: int, preemptible: bool) -> str:
preemptible_str = 'preemptible' if preemptible else 'nonpreemptible'
return f'ip-fee/{preemptible_str}/{base}'

@staticmethod
def from_dict(data: Dict[str, Any]) -> 'GCPIPFeeResource':
    """Reconstruct a GCPIPFeeResource from its serialized dict.

    The dict's 'type' field must equal GCPIPFeeResource.TYPE; only the
    'name' field is carried into the new instance.
    """
    assert data['type'] == GCPIPFeeResource.TYPE
    return GCPIPFeeResource(data['name'])

@staticmethod
def create(product_versions: ProductVersions, base: int) -> 'GCPIPFeeResource':
product = GCPIPFeeResource.product_name(base)
def create(product_versions: ProductVersions, base: int, preemptible: bool) -> 'GCPIPFeeResource':
product = GCPIPFeeResource.product_name(base, preemptible)
name = product_versions.resource_name(product)
assert name, product
return GCPIPFeeResource(name)
Expand Down
109 changes: 109 additions & 0 deletions batch/batch/cloud/gcp/worker/metadata_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from aiohttp import web

from hailtop.aiocloud import aiogoogle

from ....globals import HTTP_CLIENT_MAX_SIZE


class AppKeys:
    """Typed aiohttp application-state keys for the metadata-server app."""

    # The user job's service-account credentials, used by the handlers to
    # report the email and to mint access tokens.
    USER_CREDENTIALS = web.AppKey('credentials', aiogoogle.GoogleServiceAccountCredentials)
    # Client for the real GCE metadata server, used for project-level queries.
    GCE_METADATA_SERVER_CLIENT = web.AppKey('ms_client', aiogoogle.GoogleMetadataServerClient)


async def root(_):
    """Serve the metadata index: advertises the computeMetadata/ subtree."""
    listing = 'computeMetadata/\n'
    return web.Response(text=listing)


async def project_id(request: web.Request):
    """Serve the project id, fetched from the real GCE metadata server."""
    ms_client = request.app[AppKeys.GCE_METADATA_SERVER_CLIENT]
    project = await ms_client.project()
    return web.Response(text=project)


async def numeric_project_id(request: web.Request):
    """Serve the numeric project id, fetched from the real GCE metadata server."""
    ms_client = request.app[AppKeys.GCE_METADATA_SERVER_CLIENT]
    numeric_id = await ms_client.numeric_project_id()
    return web.Response(text=numeric_id)


async def service_accounts(request: web.Request):
    """List available service accounts: 'default' plus the user's GSA email."""
    email = request.app[AppKeys.USER_CREDENTIALS].email
    return web.Response(text='default\n{}\n'.format(email))


async def user_service_account(request: web.Request):
    """Describe the user's service account.

    With ``?recursive=true`` returns the JSON document the real server would
    serve; otherwise lists the available sub-keys.

    https://cloud.google.com/compute/docs/metadata/querying-metadata
    The token is not included in the recursive version, presumably as that is
    not simple metadata but requires requesting an access token.
    """
    email = request.app[AppKeys.USER_CREDENTIALS].email
    if request.query.get('recursive') != 'true':
        return web.Response(text='aliases\nemail\nscopes\ntoken\n')
    payload = {
        'aliases': ['default'],
        'email': email,
        'scopes': ['https://www.googleapis.com/auth/cloud-platform'],
    }
    return web.json_response(payload)


async def user_email(request: web.Request):
    """Serve the user's service-account email address."""
    credentials = request.app[AppKeys.USER_CREDENTIALS]
    return web.Response(text=credentials.email)


async def user_token(request: web.Request):
    """Mint an OAuth2 access token for the user's service account.

    NOTE(review): this reaches into the private ``_get_access_token`` of the
    credentials object rather than a public accessor — confirm this is the
    intended aiogoogle API.
    """
    credentials = request.app[AppKeys.USER_CREDENTIALS]
    token = await credentials._get_access_token()
    body = {
        'access_token': token.token,
        'expires_in': token.expires_in,
        'token_type': 'Bearer',
    }
    return web.json_response(body)


@web.middleware
async def middleware(request: web.Request, handler):
    """Validate service-account paths and normalize responses.

    Rejects requests that name a service account other than the user's own
    email or 'default', then decorates every response with the headers the
    real GCE metadata server sends so strict clients (e.g. gcloud) accept it.

    Raises:
        web.HTTPBadRequest: if the {gsa} path segment names a foreign account.
    """
    credentials = request.app[AppKeys.USER_CREDENTIALS]
    gsa = request.match_info.get('gsa')
    if gsa and gsa not in (credentials.email, 'default'):
        raise web.HTTPBadRequest()

    response = await handler(request)
    response.enable_compression()

    # `gcloud` does not properly respect `charset`, which aiohttp automatically
    # sets so we have to explicitly erase it
    # See https://github.com/googleapis/google-auth-library-python/blob/b935298aaf4ea5867b5778bcbfc42408ba4ec02c/google/auth/compute_engine/_metadata.py#L170
    # Use .get() so a response without a Content-Type header cannot raise
    # KeyError here.
    if 'application/json' in response.headers.get('Content-Type', ''):
        response.headers['Content-Type'] = 'application/json'
    response.headers['Metadata-Flavor'] = 'Google'
    response.headers['Server'] = 'Metadata Server for VM'
    response.headers['X-XSS-Protection'] = '0'
    response.headers['X-Frame-Options'] = 'SAMEORIGIN'
    return response


def create_app(
    credentials: aiogoogle.GoogleServiceAccountCredentials,
    metadata_server_client: aiogoogle.GoogleMetadataServerClient,
) -> web.Application:
    """Build the aiohttp app that impersonates the GCE metadata server for a
    user job, scoped to that job's own service-account credentials.

    The credentials client is closed when the app is cleaned up.
    """
    app = web.Application(client_max_size=HTTP_CLIENT_MAX_SIZE, middlewares=[middleware])
    app[AppKeys.USER_CREDENTIALS] = credentials
    app[AppKeys.GCE_METADATA_SERVER_CLIENT] = metadata_server_client

    sa_prefix = '/computeMetadata/v1/instance/service-accounts'
    routes = [
        web.get('/', root),
        web.get('/computeMetadata/v1/project/project-id', project_id),
        web.get('/computeMetadata/v1/project/numeric-project-id', numeric_project_id),
        web.get(f'{sa_prefix}/', service_accounts),
        web.get(f'{sa_prefix}/{{gsa}}/', user_service_account),
        web.get(f'{sa_prefix}/{{gsa}}/email', user_email),
        web.get(f'{sa_prefix}/{{gsa}}/token', user_token),
    ]
    app.add_routes(routes)

    async def close_credentials(_):
        await credentials.close()

    app.on_cleanup.append(close_credentials)
    return app
35 changes: 22 additions & 13 deletions batch/batch/cloud/gcp/worker/worker_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@
import tempfile
from typing import Dict, List

import aiohttp
import orjson
from aiohttp import web

from hailtop import httpx
from hailtop.aiocloud import aiogoogle
from hailtop.auth.auth import IdentityProvider
from hailtop.utils import check_exec_output, retry_transient_errors
from hailtop.utils import check_exec_output

from ....worker.worker_api import CloudWorkerAPI, ContainerRegistryCredentials
from ..instance_config import GCPSlimInstanceConfig
from .disk import GCPDisk
from .metadata_server import create_app


class GCPWorkerAPI(CloudWorkerAPI):
Expand All @@ -24,14 +25,24 @@ class GCPWorkerAPI(CloudWorkerAPI):
async def from_env() -> 'GCPWorkerAPI':
project = os.environ['PROJECT']
zone = os.environ['ZONE'].rsplit('/', 1)[1]
compute_client = aiogoogle.GoogleComputeClient(project)
return GCPWorkerAPI(project, zone, compute_client)
worker_credentials = aiogoogle.GoogleInstanceMetadataCredentials()
http_session = httpx.ClientSession()
return GCPWorkerAPI(project, zone, worker_credentials, http_session)

def __init__(self, project: str, zone: str, compute_client: aiogoogle.GoogleComputeClient):
def __init__(
self,
project: str,
zone: str,
worker_credentials: aiogoogle.GoogleInstanceMetadataCredentials,
http_session: httpx.ClientSession,
):
self.project = project
self.zone = zone
self._compute_client = compute_client
self._http_session = http_session
self._metadata_server_client = aiogoogle.GoogleMetadataServerClient(http_session)
self._compute_client = aiogoogle.GoogleComputeClient(project)
self._gcsfuse_credential_files: Dict[str, str] = {}
self._worker_credentials = worker_credentials

@property
def cloud_specific_env_vars_for_user_jobs(self) -> List[str]:
Expand All @@ -53,13 +64,7 @@ def create_disk(self, instance_name: str, disk_name: str, size_in_gb: int, mount
)

async def worker_container_registry_credentials(self, session: httpx.ClientSession) -> ContainerRegistryCredentials:
token_dict = await retry_transient_errors(
session.post_read_json,
'http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/token',
headers={'Metadata-Flavor': 'Google'},
timeout=aiohttp.ClientTimeout(total=60), # type: ignore
)
access_token = token_dict['access_token']
access_token = await self._worker_credentials.access_token()
return {'username': 'oauth2accesstoken', 'password': access_token}

async def user_container_registry_credentials(self, credentials: Dict[str, str]) -> ContainerRegistryCredentials:
Expand All @@ -68,6 +73,10 @@ async def user_container_registry_credentials(self, credentials: Dict[str, str])
access_token = await sa_credentials.access_token()
return {'username': 'oauth2accesstoken', 'password': access_token}

def create_metadata_server_app(self, credentials: Dict[str, str]) -> web.Application:
    """Build the per-job metadata-server app for a user job.

    The user's service-account key arrives base64-encoded under
    credentials['key.json']; it is decoded and parsed here and wrapped in
    GoogleServiceAccountCredentials for the app.
    """
    key = orjson.loads(base64.b64decode(credentials['key.json']).decode())
    return create_app(aiogoogle.GoogleServiceAccountCredentials(key), self._metadata_server_client)

def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> GCPSlimInstanceConfig:
    """Deserialize a GCPSlimInstanceConfig from its dict representation."""
    return GCPSlimInstanceConfig.from_dict(config_dict)

Expand Down
2 changes: 1 addition & 1 deletion batch/batch/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

BATCH_FORMAT_VERSION = 7
STATUS_FORMAT_VERSION = 5
INSTANCE_VERSION = 26
INSTANCE_VERSION = 27

MAX_PERSISTENT_SSD_SIZE_GIB = 64 * 1024
RESERVED_STORAGE_GB_PER_CORE = 5
34 changes: 20 additions & 14 deletions batch/batch/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import tempfile
import traceback
import uuid
import warnings
from collections import defaultdict
from contextlib import AsyncExitStack, ExitStack
from typing import (
Expand Down Expand Up @@ -95,19 +94,6 @@
with open('/subdomains.txt', 'r', encoding='utf-8') as subdomains_file:
HAIL_SERVICES = [line.rstrip() for line in subdomains_file.readlines()]

oldwarn = warnings.warn


def deeper_stack_level_warn(*args, **kwargs):
if 'stacklevel' in kwargs:
kwargs['stacklevel'] = max(kwargs['stacklevel'], 5)
else:
kwargs['stacklevel'] = 5
return oldwarn(*args, **kwargs)


warnings.warn = deeper_stack_level_warn


class BatchWorkerAccessLogger(AccessLogger):
def __init__(self, logger: logging.Logger, log_format: str):
Expand Down Expand Up @@ -263,6 +249,8 @@ async def init(self):
for service in HAIL_SERVICES:
hosts.write(f'{INTERNAL_GATEWAY_IP} {service}.hail\n')
hosts.write(f'{INTERNAL_GATEWAY_IP} internal.hail\n')
if CLOUD == 'gcp':
hosts.write('169.254.169.254 metadata metadata.google.internal')

# Jobs on the private network should have access to the metadata server
# and our vdc. The public network should not so we use google's public
Expand Down Expand Up @@ -760,6 +748,7 @@ def __init__(
command: List[str],
cpu_in_mcpu: int,
memory_in_bytes: int,
user_credentials: Optional[Dict[str, str]],
network: Optional[Union[bool, str]] = None,
port: Optional[int] = None,
timeout: Optional[int] = None,
Expand All @@ -777,6 +766,7 @@ def __init__(
self.command = command
self.cpu_in_mcpu = cpu_in_mcpu
self.memory_in_bytes = memory_in_bytes
self.user_credentials = user_credentials
self.network = network
self.port = port
self.timeout = timeout
Expand Down Expand Up @@ -820,6 +810,8 @@ def __init__(

self.monitor: Optional[ResourceUsageMonitor] = None

self.metadata_app_runner: Optional[web.AppRunner] = None

async def create(self):
self.state = 'creating'
try:
Expand Down Expand Up @@ -959,6 +951,9 @@ async def _cleanup(self):
if self._cleaned_up:
return

if self.metadata_app_runner:
await self.metadata_app_runner.cleanup()

assert self._run_fut is None
try:
if self.overlay_mounted:
Expand Down Expand Up @@ -1025,6 +1020,14 @@ async def _setup_network_namespace(self):
else:
assert self.network is None or self.network == 'public'
self.netns = await network_allocator.allocate_public()
if self.user_credentials and CLOUD == 'gcp':
assert CLOUD_WORKER_API
self.metadata_app_runner = web.AppRunner(
CLOUD_WORKER_API.create_metadata_server_app(self.user_credentials)
)
await self.metadata_app_runner.setup()
site = web.TCPSite(self.metadata_app_runner, self.netns.host_ip, 5555)
await site.start()
except asyncio.TimeoutError:
log.exception(network_allocator.task_manager.tasks)
raise
Expand Down Expand Up @@ -1454,6 +1457,7 @@ def copy_container(
cpu_in_mcpu=cpu_in_mcpu,
memory_in_bytes=memory_in_bytes,
volume_mounts=volume_mounts,
user_credentials=job.credentials,
stdin=json.dumps(files),
)

Expand Down Expand Up @@ -1778,6 +1782,7 @@ def __init__(
command=job_spec['process']['command'],
cpu_in_mcpu=self.cpu_in_mcpu,
memory_in_bytes=self.memory_in_bytes,
user_credentials=self.credentials,
network=job_spec.get('network'),
port=job_spec.get('port'),
timeout=job_spec.get('timeout'),
Expand Down Expand Up @@ -2536,6 +2541,7 @@ async def create_and_start(
command=command,
cpu_in_mcpu=n_cores * 1000,
memory_in_bytes=total_memory_bytes,
user_credentials=None,
env=[f'HAIL_WORKER_OFF_HEAP_MEMORY_PER_CORE_MB={off_heap_memory_per_core_mib}', f'HAIL_CLOUD={CLOUD}'],
volume_mounts=volume_mounts,
log_path=f'/batch/jvm-container-logs/jvm-{index}.log',
Expand Down
Loading

0 comments on commit 66e3835

Please sign in to comment.