Skip to content

Commit

Permalink
sync with fix/write-batching-2
Browse files Browse the repository at this point in the history
  • Loading branch information
dmaskasky committed Dec 17, 2024
1 parent cc72965 commit f9c3966
Showing 1 changed file with 115 additions and 155 deletions.
270 changes: 115 additions & 155 deletions src/vanilla/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,20 +166,24 @@ const addDependency = <Value>(
type Batch = Readonly<{
/** Atom dependents map */
D: Map<AnyAtom, Set<AnyAtom>>
/** High priority functions */
H: Set<() => void>
/** Medium priority functions */
M: Set<() => void>
/** Low priority functions */
L: Set<() => void>
}>

const createPending = (): Pending => [
/** dependents */
new Map(),
/** atomStates */
new Map(),
/** functions */
new Set(),
]
const createBatch = (): Batch => ({
D: new Map(),
H: new Set(),
M: new Set(),
L: new Set(),
})

const addBatchFuncHigh = (batch: Batch, fn: () => void) => {
batch.H.add(fn)
}

const addBatchFuncMedium = (batch: Batch, fn: () => void) => {
batch.M.add(fn)
Expand Down Expand Up @@ -222,6 +226,30 @@ const copySetAndClear = <T>(origSet: Set<T>): Set<T> => {
return newSet
}

const flushBatch = (batch: Batch) => {
let error: AnyError
let hasError = false
const call = (fn: () => void) => {
try {
fn()
} catch (e) {
if (!hasError) {
error = e
hasError = true
}
}
}
while (batch.M.size || batch.L.size) {
batch.D.clear()
copySetAndClear(batch.H).forEach(call)
copySetAndClear(batch.M).forEach(call)
copySetAndClear(batch.L).forEach(call)
}
if (hasError) {
throw error
}
}

// internal & unstable type
type StoreArgs = readonly [
getAtomState: <Value>(atom: Atom<Value>) => AtomState<Value>,
Expand Down Expand Up @@ -273,33 +301,6 @@ const buildStore = (
debugMountedAtoms = new Set()
}

const flushPending = (pending: Pending) => {
let error: AnyError
let hasError = false
const call = (fn: () => void) => {
try {
fn()
} catch (e) {
if (!hasError) {
error = e
hasError = true
}
}
}
while (pending[0].size || pending[1].size || pending[2].size) {
recomputeDependents(pending, new Set(pending[0].keys()))
const atomStates = new Set(pending[1].values())
pending[1].clear()
const functions = new Set(pending[2])
pending[2].clear()
atomStates.forEach((atomState) => atomState.m?.l.forEach(call))
functions.forEach(call)
}
if (hasError) {
throw error
}
}

const setAtomStateValueOrPromise = (
atom: AnyAtom,
atomState: AtomState,
Expand All @@ -314,11 +315,11 @@ const buildStore = (
addPendingPromiseToDependency(atom, valueOrPromise, getAtomState(a))
}
atomState.v = valueOrPromise
delete atomState.e
} else {
atomState.v = valueOrPromise
delete atomState.e
}
delete atomState.e
delete atomState.x
if (!hasPrevValue || !Object.is(prevValue, atomState.v)) {
++atomState.n
if (pendingPromise) {
Expand Down Expand Up @@ -347,7 +348,7 @@ const buildStore = (
([a, n]) =>
// Recursively, read the atom state of the dependency, and
// check if the atom epoch number is unchanged
readAtomState(pending, a).n === n,
readAtomState(batch, a).n === n,
)
) {
return atomState
Expand All @@ -370,7 +371,7 @@ const buildStore = (
return returnAtomValue(aState)
}
// a !== atom
const aState = readAtomState(pending, a)
const aState = readAtomState(batch, a)
try {
return returnAtomValue(aState)
} finally {
Expand Down Expand Up @@ -431,6 +432,7 @@ const buildStore = (
} catch (error) {
delete atomState.v
atomState.e = error
delete atomState.x
++atomState.n
return atomState
} finally {
Expand All @@ -441,144 +443,102 @@ const buildStore = (
const readAtom = <Value>(atom: Atom<Value>): Value =>
returnAtomValue(readAtomState(undefined, atom))

const markRecomputePending = (
pending: Pending,
atom: AnyAtom,
atomState: AtomState,
) => {
addPendingAtom(pending, atom, atomState)
if (isPendingRecompute(atom)) {
return
}
const dependents = getAllDependents(pending, [atom])
for (const [dependent] of dependents) {
getAtomState(dependent).x = true
}
}

const markRecomputeComplete = (
pending: Pending,
atom: AnyAtom,
atomState: AtomState,
) => {
delete atomState.x
pending[0].delete(atom)
}

const isPendingRecompute = (atom: AnyAtom) => getAtomState(atom).x

const getDependents = (pending: Pending, a: AnyAtom, aState: AtomState) => {
return new Set<AnyAtom>([
...(aState.m?.t || []),
...aState.p,
...(getPendingDependents(pending, a) || []),
])
}

/** @returns map of all dependents or dependencies (deep) of the root atoms */
const getDeep = (
/** function to get immediate dependents or dependencies of the atom */
getDeps: (a: AnyAtom, aState: AtomState) => Iterable<AnyAtom>,
rootAtoms: Iterable<AnyAtom>,
) => {
const visited = new Map<AnyAtom, Set<AnyAtom>>()
const stack: AnyAtom[] = Array.from(rootAtoms)
while (stack.length > 0) {
const a = stack.pop()!
const getMountedOrBatchDependents = <Value>(
batch: Batch,
atom: Atom<Value>,
atomState: AtomState<Value>,
): Map<AnyAtom, AtomState> => {
const dependents = new Map<AnyAtom, AtomState>()
for (const a of atomState.m?.t || []) {
const aState = getAtomState(a)
if (visited.has(a)) {
continue
}
const deps = new Set(getDeps(a, aState))
visited.set(a, deps)
for (const d of deps) {
if (!visited.has(d)) {
stack.push(d)
}
if (aState.m) {
dependents.set(a, aState)
}
}
return visited
for (const atomWithPendingPromise of atomState.p) {
dependents.set(
atomWithPendingPromise,
getAtomState(atomWithPendingPromise),
)
}
getBatchAtomDependents(batch, atom)?.forEach((dependent) => {
dependents.set(dependent, getAtomState(dependent))
})
return dependents
}

const getAllDependents = (pending: Pending, atoms: Iterable<AnyAtom>) =>
getDeep((a, aState) => getDependents(pending, a, aState), atoms)

// This is a topological sort via depth-first search, slightly modified from
// what's described here for simplicity and performance reasons:
// https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
const getSortedDependents = (
pending: Pending,
rootAtoms: Iterable<AnyAtom>,
const recomputeDependents = <Value>(
batch: Batch,
atom: Atom<Value>,
atomState: AtomState<Value>,
) => {
const atomMap = getAllDependents(pending, rootAtoms)
const sorted: AnyAtom[] = []
// Step 1: traverse the dependency graph to build the topsorted atom list
// We don't bother to check for cycles, which simplifies the algorithm.
// This is a topological sort via depth-first search, slightly modified from
// what's described here for simplicity and performance reasons:
// https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
const topSortedReversed: [
atom: AnyAtom,
atomState: AtomState,
epochNumber: number,
][] = []
const visiting = new Set<AnyAtom>()
const visited = new Set<AnyAtom>()
// Visit the root atoms. These are the only atoms in the dependency graph
// Visit the root atom. This is the only atom in the dependency graph
// without incoming edges, which is one reason we can simplify the algorithm
const stack: [a: AnyAtom, dependents: Set<AnyAtom>][] = []
for (const a of rootAtoms) {
if (atomMap.has(a)) {
stack.push([a, atomMap.get(a)!])
}
}
const stack: [a: AnyAtom, aState: AtomState][] = [[atom, atomState]]
while (stack.length > 0) {
const [a, dependents] = stack[stack.length - 1]!
const [a, aState] = stack[stack.length - 1]!
if (visited.has(a)) {
// All dependents have been processed, now process this atom
stack.pop()
continue
}
if (visiting.has(a)) {
// The algorithm calls for pushing onto the front of the list.
// For performance we push on the end, and will reverse the order later.
sorted.push(a)
// The algorithm calls for pushing onto the front of the list. For
// performance, we will simply push onto the end, and then will iterate in
// reverse order later.
topSortedReversed.push([a, aState, aState.n])
// Atom has been visited but not yet processed
visited.add(a)
// Mark atom dirty
aState.x = true
stack.pop()
continue
}
visiting.add(a)
// Push unvisited dependents onto the stack
for (const d of dependents) {
if (a !== d && !visiting.has(d) && atomMap.has(d)) {
stack.push([d, atomMap.get(d)!])
for (const [d, s] of getMountedOrBatchDependents(batch, a, aState)) {
if (a !== d && !visiting.has(d)) {
stack.push([d, s])
}
}
}
return sorted.reverse()
}

const recomputeDependents = (pending: Pending, rootAtoms: Set<AnyAtom>) => {
if (rootAtoms.size === 0) {
return
}
const hasChangedDeps = (aState: AtomState) =>
Array.from(aState.d.keys()).some((d) => rootAtoms.has(d))
// traverse the dependency graph to build the topsorted atom list
for (const a of getSortedDependents(pending, rootAtoms)) {
// use the topsorted atom list to recompute all affected atoms
// Track what's changed, so that we can short circuit when possible
const aState = getAtomState(a)
const prevEpochNumber = aState.n
if (isPendingRecompute(a) || hasChangedDeps(aState)) {
readAtomState(pending, a)
mountDependencies(pending, a, aState)
if (prevEpochNumber !== aState.n) {
markRecomputePending(pending, a, aState)
// Step 2: use the topSortedReversed atom list to recompute all affected atoms
// Track what's changed, so that we can short circuit when possible
addBatchFuncHigh(batch, () => {
const changedAtoms = new Set<AnyAtom>([atom])
for (let i = topSortedReversed.length - 1; i >= 0; --i) {
const [a, aState, prevEpochNumber] = topSortedReversed[i]!
let hasChangedDeps = false
for (const dep of aState.d.keys()) {
if (dep !== a && changedAtoms.has(dep)) {
hasChangedDeps = true
break
}
}
if (hasChangedDeps) {
readAtomState(batch, a)
mountDependencies(batch, a, aState)
if (prevEpochNumber !== aState.n) {
registerBatchAtom(batch, a, aState)
changedAtoms.add(a)
}
}
delete aState.x
}
markRecomputeComplete(pending, a, aState)
}
}

const recomputeDependencies = (pending: Pending, a: AnyAtom) => {
if (!isPendingRecompute(a)) {
return
}
const getDependencies = (_: unknown, aState: AtomState) => aState.d.keys()
const dependencies = Array.from(getDeep(getDependencies, [a]).keys())
const dirtyDependencies = new Set(dependencies.filter(isPendingRecompute))
recomputeDependents(pending, dirtyDependencies)
})
}

const writeAtomState = <Value, Args extends unknown[], Result>(
Expand All @@ -587,10 +547,8 @@ const buildStore = (
...args: Args
): Result => {
let isSync = true
const getter: Getter = <V>(a: Atom<V>) => {
recomputeDependencies(pending, atom)
return returnAtomValue(readAtomState(pending, a))
}
const getter: Getter = <V>(a: Atom<V>) =>
returnAtomValue(readAtomState(batch, a))
const setter: Setter = <V, As extends unknown[], R>(
a: WritableAtom<V, As, R>,
...args: As
Expand All @@ -607,7 +565,8 @@ const buildStore = (
setAtomStateValueOrPromise(a, aState, v)
mountDependencies(batch, a, aState)
if (prevEpochNumber !== aState.n) {
markRecomputePending(pending, a, aState)
registerBatchAtom(batch, a, aState)
recomputeDependents(batch, a, aState)
}
return undefined as R
} else {
Expand Down Expand Up @@ -792,7 +751,8 @@ const buildStore = (
setAtomStateValueOrPromise(atom, atomState, value)
mountDependencies(batch, atom, atomState)
if (prevEpochNumber !== atomState.n) {
markRecomputePending(pending, atom, atomState)
registerBatchAtom(batch, atom, atomState)
recomputeDependents(batch, atom, atomState)
}
}
}
Expand Down

0 comments on commit f9c3966

Please sign in to comment.