Skip to content

Commit

Permalink
Fix loop logic flaws in loader (#203)
Browse files Browse the repository at this point in the history
  • Loading branch information
cybermaggedon authored Dec 9, 2024
1 parent 803f110 commit 6103127
Showing 1 changed file with 68 additions and 14 deletions.
82 changes: 68 additions & 14 deletions trustgraph-cli/scripts/tg-load-kg-core
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ class Running:
def get(self): return self.running
def stop(self): self.running = False

ge_counts = 0
t_counts = 0

async def load_ge(running, queue, url):

global ge_counts

async with aiohttp.ClientSession() as session:

async with session.ws_connect(f"{url}load/graph-embeddings") as ws:
Expand All @@ -29,6 +34,11 @@ async def load_ge(running, queue, url):

try:
msg = await asyncio.wait_for(queue.get(), 1)

# End of load
if msg is None:
break

except:
# Hopefully it's TimeoutError. Annoying to match since
# it changed in 3.11.
Expand All @@ -45,10 +55,17 @@ async def load_ge(running, queue, url):
"entity": msg["e"],
}

await ws.send_json(msg)
try:
await ws.send_json(msg)
except Exception as e:
print(e)

ge_counts += 1

async def load_triples(running, queue, url):

global t_counts

async with aiohttp.ClientSession() as session:

async with session.ws_connect(f"{url}load/triples") as ws:
Expand All @@ -57,6 +74,11 @@ async def load_triples(running, queue, url):

try:
msg = await asyncio.wait_for(queue.get(), 1)

# End of load
if msg is None:
break

except:
# Hopefully it's TimeoutError. Annoying to match since
# it changed in 3.11.
Expand All @@ -72,27 +94,28 @@ async def load_triples(running, queue, url):
"triples": msg["t"],
}

await ws.send_json(msg)
try:
await ws.send_json(msg)
except Exception as e:
print(e)

ge_counts = 0
t_counts = 0
t_counts += 1

async def stats(running):

global t_counts
global ge_counts

while running.get():

await asyncio.sleep(2)

print(
f"Graph embeddings: {ge_counts:10d} Triples: {t_counts:10d}"
)

async def loader(running, ge_queue, t_queue, path, format, user, collection):

global t_counts
global ge_counts

if format == "json":

raise RuntimeError("Not implemented")
Expand All @@ -118,31 +141,59 @@ async def loader(running, ge_queue, t_queue, path, format, user, collection):

if unpacked[0] == "t":
qtype = t_queue
t_counts += 1
else:
if unpacked[0] == "ge":
qtype = ge_queue
ge_counts += 1

while running.get():

try:
await asyncio.wait_for(qtype.put(unpacked[1]), 0.5)

# Successful put message, move on
break

except:
# Hopefully it's TimeoutError. Annoying to match since
# it changed in 3.11.
continue

if not running.get(): break

running.stop()

# Put 'None' on end of queue to finish
while running.get():

try:
await asyncio.wait_for(t_queue.put(None), 1)

# Successful put message, move on
break

except:
# Hopefully it's TimeoutError. Annoying to match since
# it changed in 3.11.
continue

# Put 'None' on end of queue to finish
while running.get():

try:
await asyncio.wait_for(ge_queue.put(None), 1)

# Successful put message, move on
break

except:
# Hopefully it's TimeoutError. Annoying to match since
# it changed in 3.11.
continue

async def run(running, **args):

# Maxsize on queues reduces back-pressure so tg-load-kg-core doesn't
# grow to eat all memory
ge_q = asyncio.Queue(maxsize=500)
t_q = asyncio.Queue(maxsize=500)
ge_q = asyncio.Queue(maxsize=10)
t_q = asyncio.Queue(maxsize=10)

load_task = asyncio.create_task(
loader(
Expand Down Expand Up @@ -170,9 +221,12 @@ async def run(running, **args):

stats_task = asyncio.create_task(stats(running))

await load_task
await triples_task
await ge_task

running.stop()

await load_task
await stats_task

async def main(running):
Expand Down

0 comments on commit 6103127

Please sign in to comment.