-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Naive gradient descent example #75
base: master
Are you sure you want to change the base?
Changes from all commits
bff3961
5c37fc2
4397bda
c280ccb
1c6d627
d3b5127
2d84812
bca18d3
29e98a8
080a640
0b9ad84
ce98ea0
b7c2acc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
{-# LANGUAGE DataKinds #-} | ||
{-# LANGUAGE DeriveGeneric #-} | ||
{-# LANGUAGE FlexibleContexts #-} | ||
{-# LANGUAGE FlexibleInstances #-} | ||
{-# LANGUAGE MultiParamTypeClasses #-} | ||
{-# LANGUAGE RankNTypes #-} | ||
{-# LANGUAGE ScopedTypeVariables #-} | ||
{-# LANGUAGE StandaloneDeriving #-} | ||
{-# LANGUAGE TemplateHaskell #-} | ||
{-# LANGUAGE TypeApplications #-} | ||
{-# LANGUAGE TypeFamilies #-} | ||
{-# LANGUAGE TypeOperators #-} | ||
{-# LANGUAGE UndecidableInstances #-} | ||
|
||
module F | ||
( Input (..), | ||
Output (..), | ||
Param (..), | ||
XY (..), | ||
rosenbrock, | ||
dRosenbrock, | ||
wrap_rosenbrockF, | ||
wrap_dRosenbrockF, | ||
) | ||
where | ||
|
||
import qualified Categorifier.C.CExpr.Cat as C | ||
import Categorifier.C.CExpr.Cat.TargetOb (TargetOb, TargetObTC1) | ||
import Categorifier.C.CExpr.Types.Core (CExpr) | ||
import Categorifier.C.CTypes.CGeneric (CGeneric) | ||
import qualified Categorifier.C.CTypes.CGeneric as CG | ||
import Categorifier.C.CTypes.GArrays (GArrays) | ||
import Categorifier.C.KTypes.C (C) | ||
import Categorifier.C.KTypes.Function (kFunctionCall) | ||
import Categorifier.C.KTypes.KType1 (KType1) | ||
import qualified Categorifier.Categorify as Categorify | ||
import qualified Categorifier.Category as Category | ||
import Categorifier.Client (deriveHasRep) | ||
import Data.Int (Int32) | ||
import Data.Proxy (Proxy (..)) | ||
import Data.Reflection (Reifies) | ||
import Data.Word (Word64) | ||
import GHC.Generics (Generic) | ||
import Numeric.AD (grad) | ||
import Numeric.AD.Internal.Reverse (Reverse (Lift), Tape) | ||
|
||
data Param f = Param | ||
{ paramA :: f Double, | ||
paramB :: f Double | ||
} | ||
deriving (Generic) | ||
|
||
deriving instance Show (Param C) | ||
|
||
deriveHasRep ''Param | ||
|
||
instance CGeneric (Param f) | ||
|
||
instance GArrays f (Param f) | ||
|
||
type instance TargetOb (Param f) = Param (TargetObTC1 f) | ||
|
||
data XY f = XY | ||
{ xyX :: f Double, | ||
xyY :: f Double | ||
} | ||
deriving (Generic) | ||
|
||
deriving instance Show (XY C) | ||
|
||
deriveHasRep ''XY | ||
|
||
instance CGeneric (XY f) | ||
|
||
instance GArrays f (XY f) | ||
|
||
type instance TargetOb (XY f) = XY (TargetObTC1 f) | ||
|
||
data Input f = Input | ||
{ iParam :: Param f, | ||
iCoord :: XY f | ||
} | ||
deriving (Generic) | ||
|
||
deriving instance Show (Input C) | ||
|
||
deriveHasRep ''Input | ||
|
||
instance CGeneric (Input f) | ||
|
||
instance GArrays f (Input f) | ||
|
||
type instance TargetOb (Input f) = Input (TargetObTC1 f) | ||
|
||
newtype Output f = Output | ||
{oF :: f Double} | ||
deriving (Generic) | ||
|
||
deriving instance Show (Output C) | ||
|
||
deriveHasRep ''Output | ||
|
||
instance CGeneric (Output f) | ||
|
||
instance GArrays f (Output f) | ||
|
||
type instance TargetOb (Output f) = Output (TargetObTC1 f) | ||
|
||
rosenbrock :: Num a => (a, a) -> (a, a) -> a | ||
rosenbrock (a, b) (x, y) = (a - x) ^ 2 + b * (y - x ^ 2) ^ 2 | ||
|
||
dRosenbrock :: forall a. Num a => (a, a) -> (a, a) -> (a, a) | ||
dRosenbrock (a, b) (x, y) = | ||
let rosenbrock' :: forall s. Reifies s Tape => [Reverse s a] -> Reverse s a | ||
rosenbrock' [x', y'] = | ||
let a' = Lift a | ||
b' = Lift b | ||
in rosenbrock (a', b') (x', y') | ||
[dfdx, dfdy] = grad rosenbrock' [x, y] | ||
in (dfdx, dfdy) | ||
|
||
rosenbrockF :: KType1 f => Input f -> Output f | ||
rosenbrockF (Input (Param a b) (XY x y)) = Output $ rosenbrock (a, b) (x, y) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel like we can get rid of rosenbrockF :: KType1 f => (Param f, XY f) -> f Double
rosenbrockF (Param a b, XY x y) = rosenbrock (a, b) (x, y) I'd rather have this be |
||
|
||
dRosenbrockF :: forall f. (KType1 f) => Input f -> XY f | ||
dRosenbrockF (Input (Param a b) (XY x y)) = | ||
let (dfdx, dfdy) = dRosenbrock (a, b) (x, y) | ||
in XY dfdx dfdy | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar change here dRosenbrockF :: forall f. (KType1 f) => (Param f, XY f) -> XY f
dRosenbrockF (Param a b, XY x y) = uncurry XY $ dRosenbrock (a, b) (x, y) |
||
|
||
$(Categorify.separately 'rosenbrockF [t|C.Cat|] [pure [t|C|]]) | ||
|
||
$(Categorify.separately 'dRosenbrockF [t|C.Cat|] [pure [t|C|]]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Orthogonal to this change, but I do want to put together an example that does something similar to this, without Categorify.expression @C.Cat (unD (Categorify.expression @ConCat.RAD dRosenbrockF)) to illustrate nested categorification (which may not actually work yet). |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
{-# LANGUAGE ForeignFunctionInterface #-} | ||
{-# LANGUAGE OverloadedStrings #-} | ||
{-# LANGUAGE TemplateHaskell #-} | ||
{-# LANGUAGE TypeApplications #-} | ||
|
||
module Main where | ||
|
||
import Categorifier.C.Codegen.FFI.TH (embedFunction) | ||
import Categorifier.C.KTypes.C (C (unsafeC)) | ||
import Categorifier.C.KTypes.KLiteral (kliteral) | ||
import Data.Foldable (traverse_) | ||
import F | ||
( Input (..), | ||
Output (..), | ||
Param (..), | ||
XY (..), | ||
dRosenbrock, | ||
rosenbrock, | ||
wrap_dRosenbrockF, | ||
wrap_rosenbrockF, | ||
) | ||
|
||
$(embedFunction "rosenbrockF" wrap_rosenbrockF) | ||
|
||
$(embedFunction "dRosenbrockF" wrap_dRosenbrockF) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same re: |
||
|
||
gamma :: Double | ||
gamma = 0.01 | ||
|
||
step :: | ||
((Double, Double) -> IO Double) -> | ||
((Double, Double) -> IO (Double, Double)) -> | ||
(Double, Double) -> | ||
IO (Double, Double) | ||
step _f df (x0, y0) = do | ||
(dfdx, dfdy) <- df (x0, y0) | ||
let (x1, y1) = (x0 - gamma * dfdx, y0 - gamma * dfdy) | ||
pure (x1, y1) | ||
|
||
iterateNM :: (Monad m) => Int -> (a -> m a) -> a -> m [a] | ||
iterateNM n f x0 = go n x0 id | ||
where | ||
go k x acc | ||
| k > 0 = do | ||
y <- f x | ||
go (k - 1) y (acc . (y :)) | ||
| otherwise = pure (acc []) | ||
|
||
main :: IO () | ||
main = do | ||
let (x0, y0) = (0.1, 0.4) | ||
-- pure haskell | ||
putStrLn "pure haskell" | ||
let f = pure . rosenbrock (1, 10) | ||
df = pure . dRosenbrock (1, 10) | ||
histH <- iterateNM 10 (step f df) (x0, y0) | ||
traverse_ print histH | ||
|
||
-- C | ||
putStrLn "codegen C" | ||
let g (x, y) = do | ||
Output z <- hs_rosenbrockF (Input (Param 1 10) (XY (kliteral x) (kliteral y))) | ||
pure (unsafeC z) | ||
dg (x, y) = do | ||
XY x' y' <- hs_dRosenbrockF (Input (Param 1 10) (XY (kliteral x) (kliteral y))) | ||
pure (unsafeC x', unsafeC y') | ||
histC <- iterateNM 10 (step g dg) (x0, y0) | ||
traverse_ print histC |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a
Pair
type defined in categorifier, that seems better than[]
here, since it's not partial.