Separating evolution from representation #10

Open
organic-chemistry opened this issue Jul 7, 2020 · 1 comment
@organic-chemistry

Hello, and first of all thank you for this amazing post.
I tried to modify the model to separate the evolution from the representation:
one function evolves the state, and at the end a second function takes the
evolved state and computes a representation, which is compared with the image to fit.
(I also changed the living_channel to the first channel, and this worked fine.)
However, it seems that the gradient is not propagated to the weights of the evolution
layer.
Do you know why?

class CAModel(tf.keras.Model):

  def __init__(self, channel_n=CHANNEL_N, fire_rate=CELL_FIRE_RATE):
    super().__init__()
    self.channel_n = channel_n
    self.fire_rate = fire_rate

    input_with_gradient = tf.keras.Input(shape=(None,None,self.channel_n*3),
                                         name="gradient")
    current_state = tf.keras.Input(shape=(None,None,self.channel_n),
                                         name="current")
    
    evolution =  layers.Conv2D(self.channel_n, 1, activation=tf.nn.relu,
                               name="evolution")(input_with_gradient)

    representation = layers.Conv2D(3, 1, 
                                   activation=tf.nn.relu,
                                   name="representation")(current_state)
    self.model = tf.keras.Model(inputs=[current_state,
                                 input_with_gradient],
                                outputs=[evolution,representation], name="global")

    self(tf.zeros([1, 3, 3, channel_n]))  # dummy call to build the model

  @tf.function
  def perceive(self, x, angle=0.0):
    identify = np.float32([0, 1, 0])
    identify = np.outer(identify, identify)
    dx = np.outer([1, 2, 1], [-1, 0, 1]) / 8.0  # Sobel filter
    dy = dx.T
    c, s = tf.cos(angle), tf.sin(angle)
    kernel = tf.stack([identify, c*dx-s*dy, s*dx+c*dy], -1)[:, :, None, :]
    kernel = tf.repeat(kernel, self.channel_n, 2)
    y = tf.nn.depthwise_conv2d(x, kernel, [1, 1, 1, 1], 'SAME')
    return y

  @tf.function
  def call(self, x, fire_rate=None, angle=0.0, step_size=1.0):
    pre_life_mask = get_living_mask(x)

    y = self.perceive(x, angle)
    dx,representation = self.model([x,y])
    dx = dx*step_size
    if fire_rate is None:
      fire_rate = self.fire_rate
    update_mask = tf.random.uniform(tf.shape(x[:, :, :, :1])) <= fire_rate
    x += dx * tf.cast(update_mask, tf.float32)

    post_life_mask = get_living_mask(x)
    life_mask = pre_life_mask & post_life_mask
    casted_life_mask = tf.cast(life_mask, tf.float32)

    # The representation is the alpha channel (the life mask) first,
    # followed by the RGB channels.

    return x * casted_life_mask , tf.concat([casted_life_mask,
                                             representation * casted_life_mask],
                                            axis=-1)

and for the evolution of the state:

for i in tf.range(iter_n):
  x, representation = ca(x)
loss = tf.reduce_mean(loss_f(representation, img))
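
One thing that may be worth checking in the loop above: in `call`, `representation` is computed from the state that is passed in, before `dx` is added, so the representation returned at the last step never sees that step's evolution output. The toy below is a scalar sketch of that timing in plain Python, not the actual model: `run`, `loss`, `finite_difference`, the weights and the target are all hypothetical stand-ins, and a finite difference replaces the real gradient. Under those assumptions the loss does not depend on the evolution weight when `iter_n` is 1, and it does from 2 steps onward. Whether this explains the behaviour described in the issue is not established.

```python
def run(w_evolution, w_representation, iter_n, x0=1.0):
    """Scalar stand-in for the training loop; returns the last representation."""
    x = x0
    representation = None
    for _ in range(iter_n):
        # As in call(): the readout uses the incoming state, before the update.
        representation = w_representation * x
        # relu-gated evolution update (hypothetical 1-weight "layer").
        x = x + max(0.0, w_evolution * x)
    return representation


def loss(w_evolution, iter_n, target=5.0):
    # Squared error between the last representation and a made-up target.
    return (run(w_evolution, 2.0, iter_n) - target) ** 2


def finite_difference(iter_n, w=0.5, eps=1e-6):
    # Central difference of the loss with respect to the evolution weight.
    return (loss(w + eps, iter_n) - loss(w - eps, iter_n)) / (2 * eps)


grad_one_step = finite_difference(1)   # readout only ever sees x0
grad_two_steps = finite_difference(2)  # readout sees the first update
print(grad_one_step, grad_two_steps)
```

If the readout were moved after the state update, the last evolution step would also contribute to the loss in this toy.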

organic-chemistry commented Jul 7, 2020

So I changed the representation layer to

representation = layers.Conv2D(3, 1,
                               activation=None,
                               kernel_initializer=tf.zeros_initializer,
                               name="representation")(current_state)

I then checked that the gradient of the evolution layer is not null,
but it seems that the loss decreases at the beginning and then plateaus at a high value,
where the image does not look at all like the target.
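
With a zero-initialized representation kernel, the gradient reaching the evolution layer is expected to be exactly zero on the very first step and to become non-zero only once the representation weights have moved. A rough way to see this, and to tell "not connected" (`None`) from "connected but zero" (`0.0`), is sketched below with `tf.GradientTape`. The helper `gradient_report` and the toy 1x1 variables are hypothetical stand-ins, not the `CAModel` layers, and the sketch assumes TensorFlow 2 in eager mode.

```python
import tensorflow as tf


def gradient_report(loss_fn, variables):
    """Map each name in `variables` to the norm of d(loss)/d(variable).

    None means the variable is not connected to the loss at all;
    0.0 means it is connected but the gradient vanishes at this point.
    """
    names = list(variables)
    with tf.GradientTape() as tape:
        loss = loss_fn()
    grads = tape.gradient(loss, [variables[n] for n in names])
    return {n: None if g is None else float(tf.norm(g))
            for n, g in zip(names, grads)}


# Toy stand-ins (hypothetical): one weight per "layer".
w_evolution = tf.Variable([[0.5]])
w_representation = tf.Variable([[0.0]])  # zero init, as in the comment above
x0 = tf.constant([[1.0]])
target = tf.constant([[3.0]])


def toy_loss():
    x = x0 + tf.nn.relu(tf.matmul(x0, w_evolution))  # one evolution step
    representation = tf.matmul(x, w_representation)  # linear readout
    return tf.reduce_mean(tf.square(representation - target))


toy_variables = {"evolution": w_evolution, "representation": w_representation}

at_init = gradient_report(toy_loss, toy_variables)
w_representation.assign([[1.0]])  # pretend one optimizer step moved the readout
after_update = gradient_report(toy_loss, toy_variables)
print(at_init, after_update)
```

Calling the same helper with the real model's `trainable_variables` in place of the toy dictionary would show which layers receive a gradient at a given training step.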
