Skip to content

Commit

Permalink
refactoring postprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
jtristan committed May 16, 2024
1 parent b5f2c29 commit 079b482
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 33 deletions.
2 changes: 1 addition & 1 deletion SampCert.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Jean-Baptiste Tristan
-/

import SampCert.DifferentialPrivacy.ZeroConcentrated.Queries.BoundedMean.Code
import SampCert.DifferentialPrivacy.ZeroConcentrated.Queries.BoundedMean.Basic
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,17 @@ open Classical Nat Int Real ENNReal MeasureTheory Measure

namespace SLang

theorem foo (f : U → ℤ) (g : U → ENNReal) (x : ℤ) :
variable {T : Type}
variable [t1 : MeasurableSpace T]
variable [t2 : MeasurableSingletonClass T]

variable {U : Type}
variable [m2 : MeasurableSpace U]
variable [count : Countable U]
variable [disc : DiscreteMeasurableSpace U]
variable [Inhabited U]

theorem condition_to_subset (f : U → ℤ) (g : U → ENNReal) (x : ℤ) :
(∑' a : U, if x = f a then g a else 0) = ∑' a : { a | x = f a }, g a := by
have A := @tsum_split_ite U (fun a : U => x = f a) g (fun _ => 0)
simp only [decide_eq_true_eq, tsum_zero, add_zero] at A
Expand All @@ -22,14 +32,9 @@ theorem foo (f : U → ℤ) (g : U → ENNReal) (x : ℤ) :
simp
rw [B]

variable {T : Type}
variable [m1 : MeasurableSpace T]
variable [m2 : MeasurableSingletonClass T]
variable [m3: MeasureSpace T]

theorem Integrable_rpow (f : T → ℝ) (nn : ∀ x : T, 0 ≤ f x) (μ : Measure T) (α : ENNReal) (mem : Memℒp f α μ) (h1 : α ≠ 0) (h2 : α ≠ ⊤) :
MeasureTheory.Integrable (fun x : T => (f x) ^ α.toReal) μ := by
have X := @MeasureTheory.Memℒp.integrable_norm_rpow T ℝ m1 μ _ f α mem h1 h2
have X := @MeasureTheory.Memℒp.integrable_norm_rpow T ℝ t1 μ _ f α mem h1 h2
revert X
conv =>
left
Expand All @@ -51,7 +56,7 @@ theorem Integrable_rpow (f : T → ℝ) (nn : ∀ x : T, 0 ≤ f x) (μ : Measur
. rw [← hasFiniteIntegral_norm_iff]
simp [X]

theorem bar (f : T → ℝ) (q : PMF T) (α : ℝ) (h : 1 < α) (h2 : ∀ x : T, 0 ≤ f x) (mem : Memℒp f (ENNReal.ofReal α) (PMF.toMeasure q)) :
theorem Renyi_Jensen (f : T → ℝ) (q : PMF T) (α : ℝ) (h : 1 < α) (h2 : ∀ x : T, 0 ≤ f x) (mem : Memℒp f (ENNReal.ofReal α) (PMF.toMeasure q)) :
((∑' x : T, (f x) * (q x).toReal)) ^ α ≤ (∑' x : T, (f x) ^ α * (q x).toReal) := by

conv =>
Expand Down Expand Up @@ -91,7 +96,7 @@ theorem bar (f : T → ℝ) (q : PMF T) (α : ℝ) (h : 1 < α) (h2 : ∀ x : T,
simp at h''
have C : @IsClosed ℝ UniformSpace.toTopologicalSpace (Set.Ici 0) := by
exact isClosed_Ici
have D := @ConvexOn.map_integral_le T ℝ m1 _ _ _ (PMF.toMeasure q) (Set.Ici 0) f (fun (x : ℝ) => x ^ α) (PMF.toMeasure.isProbabilityMeasure q) A B C
have D := @ConvexOn.map_integral_le T ℝ t1 _ _ _ (PMF.toMeasure q) (Set.Ici 0) f (fun (x : ℝ) => x ^ α) (PMF.toMeasure.isProbabilityMeasure q) A B C
simp at D
apply D
. exact MeasureTheory.ae_of_all (PMF.toMeasure q) h2
Expand All @@ -104,7 +109,7 @@ theorem bar (f : T → ℝ) (q : PMF T) (α : ℝ) (h : 1 < α) (h2 : ∀ x : T,
apply lt_trans zero_lt_one h
have Y : ENNReal.ofReal α ≠ ⊤ := by
simp
have Z := @Integrable_rpow T m1 f h2 (PMF.toMeasure q) (ENNReal.ofReal α) mem X Y
have Z := @Integrable_rpow T t1 f h2 (PMF.toMeasure q) (ENNReal.ofReal α) mem X Y
rw [toReal_ofReal] at Z
. exact Z
. apply le_of_lt
Expand All @@ -114,7 +119,7 @@ theorem bar (f : T → ℝ) (q : PMF T) (α : ℝ) (h : 1 < α) (h2 : ∀ x : T,
apply lt_trans zero_lt_one h
have Y : ENNReal.ofReal α ≠ ⊤ := by
simp
have Z := @Integrable_rpow T m1 f h2 (PMF.toMeasure q) (ENNReal.ofReal α) mem X Y
have Z := @Integrable_rpow T t1 f h2 (PMF.toMeasure q) (ENNReal.ofReal α) mem X Y
rw [toReal_ofReal] at Z
. exact Z
. apply le_of_lt
Expand All @@ -123,11 +128,6 @@ theorem bar (f : T → ℝ) (q : PMF T) (α : ℝ) (h : 1 < α) (h2 : ∀ x : T,
rw [one_le_ofReal]
apply le_of_lt h

variable {U : Type}
variable [m2 : MeasurableSpace U] -- [m2' : MeasurableSingletonClass U]
variable [count : Countable U]
variable [disc : DiscreteMeasurableSpace U]

def δ (nq : SLang U) (f : U → ℤ) (a : ℤ) : {n : U | a = f n} → ENNReal := fun x : {n : U | a = f n} => nq x * (∑' (x : {n | a = f n}), nq x)⁻¹

theorem δ_normalizes (nq : SLang U) (f : U → ℤ) (a : ℤ) (h1 : ∑' (i : ↑{n | a = f n}), nq ↑i ≠ 0) (h2 : ∑' (i : ↑{n | a = f n}), nq ↑i ≠ ⊤) :
Expand Down Expand Up @@ -193,7 +193,7 @@ theorem ENNReal.tsum_fiberwise (p : T → ENNReal) (f : T → ℤ) :
apply Summable.hasSum
exact ENNReal.summable

theorem quux (p : T → ENNReal) (f : T → ℤ) :
theorem fiberwisation (p : T → ENNReal) (f : T → ℤ) :
(∑' i : T, p i)
= ∑' (x : ℤ), if {a : T | x = f a} = {} then 0 else ∑'(i : {a : T | x = f a}), p i := by
rw [← ENNReal.tsum_fiberwise p f]
Expand All @@ -220,7 +220,7 @@ theorem quux (p : T → ENNReal) (f : T → ℤ) :

theorem convergent_subset {p : T → ENNReal} (f : T → ℤ) (conv : ∑' (x : T), p x ≠ ⊤) :
∑' (x : { y : T| x = f y }), p x ≠ ⊤ := by
rw [← foo]
rw [← condition_to_subset]
have A : (∑' (y : T), if x = f y then p y else 0) ≤ ∑' (x : T), p x := by
apply tsum_le_tsum
. intro i
Expand Down Expand Up @@ -275,7 +275,7 @@ theorem DPostPocess_pre {nq : List T → SLang U} {ε₁ ε₂ : ℕ+} (h : DP n
rw [@RenyiDivergenceExpectation _ (nq l₁) (nq l₂) _ h1 (nn l₂) (nts l₂)]

-- Shuffle the sum
rw [quux (fun x => (nq l₁ x / nq l₂ x) ^ α * nq l₂ x) f]
rw [fiberwisation (fun x => (nq l₁ x / nq l₂ x) ^ α * nq l₂ x) f]

apply ENNReal.tsum_le_tsum

Expand All @@ -284,7 +284,7 @@ theorem DPostPocess_pre {nq : List T → SLang U} {ε₁ ε₂ : ℕ+} (h : DP n
-- Get rid of elements with probability 0 in the pushforward
split
. rename_i empty
rw [foo]
rw [condition_to_subset]
have ZE : (∑' (x_1 : ↑{n | i = f n}), nq l₁ ↑x_1) = 0 := by
simp
intro a H
Expand Down Expand Up @@ -373,8 +373,8 @@ theorem DPostPocess_pre {nq : List T → SLang U} {ε₁ ε₂ : ℕ+} (h : DP n
intro a
apply ENNReal.ne_top_of_tsum_ne_top S3

rw [foo]
rw [foo]
rw [condition_to_subset]
rw [condition_to_subset]

-- Introduce Q(f⁻¹ i)
let κ := ∑' x : {n : U | i = f n}, nq l₂ x
Expand Down Expand Up @@ -478,7 +478,7 @@ theorem DPostPocess_pre {nq : List T → SLang U} {ε₁ ε₂ : ℕ+} (h : DP n
apply S3


have Jensen's := @bar {n : U | i = f n} Subtype.instMeasurableSpace Subtype.instMeasurableSingletonClass (fun a => (nq l₁ a / nq l₂ a).toReal) (δpmf (nq l₂) f i (MasterZero l₂) (MasterRW l₂)) α h1 P5 XXX
have Jensen's := @Renyi_Jensen {n : U | i = f n} Subtype.instMeasurableSpace Subtype.instMeasurableSingletonClass (fun a => (nq l₁ a / nq l₂ a).toReal) (δpmf (nq l₂) f i (MasterZero l₂) (MasterRW l₂)) α h1 P5 XXX
clear P5

have P6 : 0 ≤ (∑' (x : ↑{n | i = f n}), nq l₂ ↑x).toReal := by
Expand Down Expand Up @@ -556,11 +556,6 @@ theorem DPostPocess_pre {nq : List T → SLang U} {ε₁ ε₂ : ℕ+} (h : DP n
simp only [one_mul]

rw [ENNReal.tsum_mul_right]
have H1 : 0 ≤ ∑' (x : ↑{n | i = f n}), (nq l₁ ↑x).toReal := by
apply tsum_nonneg
simp
have H2 : 0 ≤ (∑' (a : ↑{a | i = f a}), nq l₂ ↑a)⁻¹.toReal := by
apply toReal_nonneg
have H4 : (∑' (a : ↑{a | i = f a}), nq l₂ ↑a)⁻¹ ≠ ⊤ := by
apply inv_ne_top.mpr
simp
Expand Down Expand Up @@ -600,8 +595,6 @@ theorem tsum_ne_zero_of_ne_zero {T : Type} [Inhabited T] (f : T → ENNReal) (h
have B := CONTRA default
contradiction

variable [Inhabited U]

theorem DPPostProcess {nq : List T → SLang U} {ε₁ ε₂ : ℕ+} (h : DP nq ((ε₁ : ℝ) / ε₂)) (nn : NonZeroNQ nq) (nt : NonTopRDNQ nq) (nts : NonTopNQ nq) (conv : NonTopSum nq) (f : U → ℤ) :
DP (PostProcess nq f) ((ε₁ : ℝ) / ε₂) := by
simp [PostProcess, DP, RenyiDivergence]
Expand Down Expand Up @@ -690,5 +683,4 @@ theorem DPPostProcess {nq : List T → SLang U} {ε₁ ε₂ : ℕ+} (h : DP nq
rw [lt_top_iff_ne_top] at Z
contradiction


end SLang
2 changes: 1 addition & 1 deletion Tests/testing-kolmogorov-discretegaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_kolmogorov_dist(N, sigma2, with_plots=False):
print("* Calling the 'test_kolmogorov_dist' function with N=1000 and location parameter sigma^2=10 (without plots):")
# How to use the "test_kolmogorov_dist" function: on N=10000 samples, with sigma^2 = 10 (no plots)
diff = test_kolmogorov_dist(N,sig2)
if diff < 0.01:
if diff < 0.02:
print("Test passed!")
exit(0)
else:
Expand Down

0 comments on commit 079b482

Please sign in to comment.