Skip to content

Commit

Permalink
Fix deduplicating dependent tasks + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
whimo committed May 3, 2024
1 parent 814cfc5 commit 21172b9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 23 deletions.
6 changes: 4 additions & 2 deletions motleycrew/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@ def set_upstream(self, task: Task) -> Task:
if task is self:
raise TaskDependencyCycleError(f"Task {task.name} can not depend on itself")

self.upstream_tasks.append(task)
task.downstream_tasks.append(self)
if task not in self.upstream_tasks:
self.upstream_tasks.append(task)
if self not in task.downstream_tasks:
task.downstream_tasks.append(self)

return self

Expand Down
42 changes: 21 additions & 21 deletions tests/test_tasks/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ def test_set_upstream_returns_self(self, task1, task2):
def test_set_upstream_sets_upstream(self, task1, task2):
task2.set_upstream(task1)

assert task1.upstream_tasks == set()
assert task2.upstream_tasks == {task1}
assert task1.upstream_tasks == []
assert task2.upstream_tasks == [task1]

def test_set_upstream_sets_downstreams(self, task1, task2):
task2.set_upstream(task1)

assert task1.downstream_tasks == {task2}
assert task2.downstream_tasks == set()
assert task1.downstream_tasks == [task2]
assert task2.downstream_tasks == []

def test_rshift_returns_left(self, task1, task2):
result = task1 >> task2
Expand All @@ -58,28 +58,28 @@ def test_rshift_returns_left(self, task1, task2):
def test_rshift_sets_downstream(self, task1, task2):
task1 >> task2

assert task1.downstream_tasks == {task2}
assert task2.downstream_tasks == set()
assert task1.downstream_tasks == [task2]
assert task2.downstream_tasks == []

def test_rshift_sets_upstream(self, task1, task2):
task1 >> task2

assert task1.upstream_tasks == set()
assert task2.upstream_tasks == {task1}
assert task1.upstream_tasks == []
assert task2.upstream_tasks == [task1]

def test_rshift_set_multiple_downstream(self, task1, task2, task3):
task1 >> [task2, task3]

assert task1.downstream_tasks == {task2, task3}
assert task2.downstream_tasks == set()
assert task3.downstream_tasks == set()
assert task1.downstream_tasks == [task2, task3]
assert task2.downstream_tasks == []
assert task3.downstream_tasks == []

def test_rshift_set_multiple_upstream(self, task1, task2, task3):
task1 >> [task2, task3]

assert task1.upstream_tasks == set()
assert task2.upstream_tasks == {task1}
assert task3.upstream_tasks == {task1}
assert task1.upstream_tasks == []
assert task2.upstream_tasks == [task1]
assert task3.upstream_tasks == [task1]

def test_sequence_on_left_returns_sequence(self, task1, task2, task3):
result = [task1, task2] >> task3
Expand All @@ -89,21 +89,21 @@ def test_sequence_on_left_returns_sequence(self, task1, task2, task3):
def test_sequence_on_left_sets_downstream(self, task1, task2, task3):
[task1, task2] >> task3

assert task1.downstream_tasks == {task3}
assert task2.downstream_tasks == {task3}
assert task3.downstream_tasks == set()
assert task1.downstream_tasks == [task3]
assert task2.downstream_tasks == [task3]
assert task3.downstream_tasks == []

def test_sequence_on_left_sets_upstream(self, task1, task2, task3):
[task1, task2] >> task3

assert task1.upstream_tasks == set()
assert task2.upstream_tasks == set()
assert task3.upstream_tasks == {task1, task2}
assert task1.upstream_tasks == []
assert task2.upstream_tasks == []
assert task3.upstream_tasks == [task1, task2]

def test_deduplicates(self, task1, task2):
task1 >> [task2, task2]

assert task1.downstream_tasks == {task2}
assert task1.downstream_tasks == [task2]

def test_error_on_direct_dependency_cycle(self, task1):
with pytest.raises(TaskDependencyCycleError):
Expand Down

0 comments on commit 21172b9

Please sign in to comment.