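"""extract_SLP_files

Find SLEAP label files (`*.slp`) on a GCP VM, copy them to a local directory
with `gcloud compute scp`, and record the copied files in a CSV.
"""
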
import os
import subprocess
from pathlib import Path

# TODO: These modules need to be pip installed into the vmassign_venv at the moment
import pandas as pd
from tqdm import tqdm

from vmassign import config, SpannerDatabase


def add_gcloud_to_path(where_gcloud):
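    """Append the directory containing the gcloud executable to PATH."""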
os.environ["PATH"] += os.pathsep + str(Path(where_gcloud).parent)


def gcloud_in_path():
    """Return True if the `gcloud` executable can be invoked from PATH."""
    try:
        subprocess.run(
            "gcloud --version", check=True, shell=True, capture_output=True
        )
    except (FileNotFoundError, subprocess.CalledProcessError):
        return False
    return True


def raise_if_gcloud_not_in_path(where_gcloud=None):
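    """Raise FileNotFoundError if `gcloud` cannot be found on PATH.

    If `where_gcloud` (the path printed by `where gcloud`) is given, its
    parent directory is added to PATH first.
    """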
if where_gcloud is not None:
add_gcloud_to_path(where_gcloud)
if not gcloud_in_path():
message = (
"gcloud not found in PATH. "
"Please install Google Cloud SDK if you have not already. "
"Pass the path returned by `where gcloud` to the variable where_gcloud."
" Then retry running this script."
)
raise FileNotFoundError(message)


def copy_remote_files_to_dir(vm_name, remote_files, local_dir):
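    """Copy `remote_files` from the VM to `local_dir` via `gcloud compute scp`.

    Each file is saved locally as `<vm_name>.<original filename>`. Returns the
    list of local paths of the copied files.
    """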
local_dir = Path(local_dir)
local_files = []
for remote_file in remote_files:
remote_file = Path(remote_file)
local_file = local_dir / f"{vm_name}.{remote_file.name}"
cmd = [
"gcloud",
"compute",
"scp",
f"{vm_name}:{remote_file.as_posix()}",
f"{local_file.as_posix()}",
"--zone",
config.VM_ZONE,
]
print(f'Running command: {" ".join(cmd)}')
        try:
            subprocess.run(
                " ".join(cmd), check=True, shell=True, capture_output=True, text=True
            )
            local_files.append(local_file.as_posix())
        except subprocess.CalledProcessError as e:
            print(f"Error running command: {e.stderr}")
            raise
return local_files


def get_assigned_vms_from_database():
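    """Return a dict mapping each assigned VM hostname to its user's email."""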
project_id = config.PROJECT_ID
instance_id = config.DB_INSTANCE_ID
database_id = config.DB_DATABASE_ID
table_name = config.DB_TABLE_NAME
spanner_db = SpannerDatabase.load_database(
project_id, instance_id, database_id, table_name
)
assigned_vms = spanner_db.get_assigned_vms()
user_by_vm = {}
for vm_name in assigned_vms:
user = spanner_db.get_user_email(hostname=vm_name)
user_by_vm[vm_name] = user
return user_by_vm


def find_label_files(vm_name: str, remote_dir: str):
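    """List `*.slp` files under `remote_dir` on the VM via `gcloud compute ssh`.

    Paths under `models/` and `predictions/` are excluded. Returns a list of
    remote paths ([] if none are found), or None if the command fails.
    """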
remote_dir = Path(remote_dir)
    cmd = (
        f"gcloud compute ssh {vm_name} "
        f"--zone {config.VM_ZONE} "
        f"--command='find {remote_dir.as_posix()} -name \"*.slp\" "
        f"-not -path \"*/models/*\" -not -path \"*/predictions/*\"'"
    )
    print(f"Running command: {cmd}")
    try:
        result = subprocess.run(
            cmd, shell=True, check=True, capture_output=True, text=True
        )
        # One remote path per line of stdout; strip the trailing newline.
        slp_files = result.stdout.strip().split("\n")
        return slp_files if slp_files != [""] else []
    except subprocess.CalledProcessError as e:
        print(f"Error running command: {e.stderr}")
        return None


def save_to_csv(data, local_dir):
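    """Write the collected file records to `vm_data.csv` in `local_dir`."""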
df = pd.DataFrame(data)
csv_file = Path(local_dir) / "vm_data.csv"
df.to_csv(csv_file, index=False)


def test_main(remote_dir, local_dir, vm_name):
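    """Copy all SLP files from one VM to `local_dir` and log them to a CSV.

    The CSV is written even if an error occurs partway through.
    """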
    # TODO: Call raise_if_gcloud_not_in_path(where_gcloud=...) first if gcloud
    # may not already be on PATH.
    # Create dictionary to store data; a "user-email" column could be added
    # using get_assigned_vms_from_database().
data = {"filename": [], "vm-hostname": []}
# Get data for each VM
try:
# Get list of SLP files in remote directory
remote_files = find_label_files(vm_name, remote_dir)
# Print the list of SLP files
        if remote_files:
            for file in remote_files:
                print(file)
        else:
            # find_label_files returns None on error and [] when nothing matches.
            print("No SLP files found.")
            return
        # Copy remote files to local directory
        local_files = copy_remote_files_to_dir(
            vm_name=vm_name, remote_files=remote_files, local_dir=local_dir
        )
# Store data
for local_file in local_files:
data["filename"].append(local_file)
data["vm-hostname"].append(vm_name)
finally:
save_to_csv(data, local_dir)


if __name__ == "__main__":
vm_name = "justinshen-vm-001-20240814090614"
remote_dir = "/home/liezl/sleap-datasets/drosophila-melanogaster-courtship"
local_dir = "/Users/justinshen/Downloads/vmdata"
where_gcloud = r"path\to\Google\Cloud SDK\google-cloud-sdk\bin\gcloud"
test_main(
        vm_name=vm_name,
remote_dir=remote_dir,
local_dir=local_dir,
# where_gcloud=where_gcloud, # TODO: Pass in output of `where gcloud` if needed
)