Multiple outputs from a keras model #36116

Closed
Bidski opened this issue Jan 22, 2020 · 13 comments
Labels: comp:keras, stale, stat:awaiting response, TF 2.1, type:bug

Comments

@Bidski

Bidski commented Jan 22, 2020


System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: N/A
  • TensorFlow installed from (source or binary): Source
  • TensorFlow version (use command below): 1.15.0-rc1-11276-gc9f7f636eb 2.1.0
  • Python version: 3.6.9
  • Bazel version (if compiling from source): 2.0.0
  • GCC/Compiler version (if compiling from source): gcc-7
  • CUDA/cuDNN version: 10.1/7.6.3
  • GPU model and memory: GTX1080Ti 11GB


Describe the current behavior
I have a custom (keras) CNN model as well as a custom loss function.

The model has two inputs at one resolution and multiple (6) outputs at different resolutions (each output has a different resolution).

The dataset, from a TFRecord file, has the 2 image inputs and 1 ground truth image as an output.

The loss function expects to receive the single ground truth image as y_true and the 6 outputs in a list as y_pred and will then calculate the loss value based on this.

With this scenario, I get the following error:

ValueError: Error when checking model target: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 6 array(s), for inputs ['output_1', 'output_2', 'output_3', 'output_4', 'output_5', 'output_6'] but instead got the following list of 1 arrays: [<tf.Tensor 'args_5:0' shape=(None, 384, 768, 1) dtype=float32>]...

If I modify my dataset loading code so that it resizes the ground truth image into a list of images with appropriate resolutions to match my network's outputs, I get the following error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [270,480,1] != values[1].shape = [135,240,1]
	 [[{{node packed}}]]

Describe the expected behavior
I expect that TF/keras would allow at least one of these scenarios.

Is there an accepted way to handle this sort of situation?
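A possible explanation for the second error, stated here as an assumption rather than something confirmed in this thread: when a `Dataset.map` function returns a Python list, tf.data tries to pack the list into a single tensor, which fails if the elements have different shapes. A tuple, by contrast, is kept as a nested structure with one component per output. The sketch below uses zero-filled placeholder tensors in place of the TFRecord data and only two scales; it has not been checked against TF 2.1.

```python
import tensorflow as tf


def scale_output_resolution(x, y):
    # Build one resized ground-truth image per output scale.
    resized = [
        tf.image.resize(y, size=(540 // 2 ** n, 960 // 2 ** n), method="nearest")
        for n in (1, 2)
    ]
    # Return a tuple, not a list: a list would be packed into one tensor,
    # which cannot work when the components have different shapes.
    return x, tuple(resized)


# Placeholder data standing in for the decoded TFRecord entries (assumption).
images = tf.zeros([2, 540, 960, 3])
disparity = tf.zeros([2, 540, 960, 1])

dataset = tf.data.Dataset.from_tensor_slices(((images, images), disparity))
dataset = dataset.map(scale_output_resolution).batch(1)

# The target part of each element is now a tuple of two tensors.
spec = dataset.element_spec
print(spec[1])
```

With six outputs, the same pattern would yield a six-element tuple, one target per model output.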

@ravikyram ravikyram self-assigned this Jan 22, 2020
@ravikyram ravikyram added the TF 2.1 for tracking issues in 2.1 release label Jan 22, 2020
@ravikyram
Contributor

@Bidski

Can you please provide a Colab link or simple standalone code to reproduce the issue in our environment? It helps us localize the issue faster. Thanks!

@ravikyram ravikyram added the stat:awaiting response Status - Awaiting response from author label Jan 22, 2020
@Bidski
Author

Bidski commented Jan 22, 2020

import sys

import tensorflow as tf


@tf.function
def parse_entry(entry):
    feature_description = {
        "image_l": tf.io.FixedLenFeature(shape=[], dtype=tf.string),
        "image_r": tf.io.FixedLenFeature(shape=[], dtype=tf.string),
        "disparity_l": tf.io.FixedLenSequenceFeature(shape=[], dtype=tf.float32, allow_missing=True),
    }

    example = tf.io.parse_single_example(entry, feature_description)

    example["image_l"] = tf.io.decode_image(example["image_l"], channels=0, dtype=tf.dtypes.uint8)
    example["image_r"] = tf.io.decode_image(example["image_r"], channels=0, dtype=tf.dtypes.uint8)
    example["disparity_l"] = tf.reshape(example["disparity_l"], (540, 960, 1))

    return (example["image_l"], example["image_r"]), example["disparity_l"]


@tf.function
def normalise(x, y):
    img_l = tf.image.convert_image_dtype(x[0], dtype=tf.dtypes.float32, name="convert1")
    img_r = tf.image.convert_image_dtype(x[1], dtype=tf.dtypes.float32, name="convert2")

    return (img_l, img_r), y


@tf.function
def fixup_shape(x, y):
    x[0].set_shape([540, 960, 3])
    x[1].set_shape([540, 960, 3])
    y.set_shape([540, 960, 1])

    return x, y


@tf.function
def scale_output_resolution(x, y):
    if False:
        disparities = [
            tf.image.resize(
                y,
                size=(tf.math.divide(540, tf.math.pow(2, n)), tf.math.divide(960, tf.math.pow(2, n))),
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
            )
            for n in range(1, 3)
        ]
    else:
        if False:
            disparities = [y for _ in range(1, 3)]
        else:
            disparities = y

    return x, disparities


def create_dataset(path):
    dataset = tf.data.TFRecordDataset(path)
    dataset = dataset.map(parse_entry)
    dataset = dataset.map(normalise)
    dataset = dataset.map(fixup_shape)
    dataset = dataset.map(scale_output_resolution)
    dataset = dataset.batch(batch_size=1, drop_remainder=False)

    return dataset


class TestModel(tf.keras.Model):
    def __init__(self, **kwargs):
        super(TestModel, self).__init__(**kwargs)

        self.conv1 = tf.keras.layers.Conv2D(filters=2, kernel_size=(7, 7), strides=(2, 2), padding="same", name="conv1")
        self.conv2 = tf.keras.layers.Conv2D(filters=1, kernel_size=(7, 7), strides=(2, 2), padding="same", name="conv2")

    def call(self, inputs):
        x = self.conv1(tf.concat(inputs, axis=-1, name="concat"))
        y = self.conv2(x)

        return [x, y]


if __name__ == "__main__":
    tf.config.experimental_run_functions_eagerly(True)

    train_dataset = create_dataset(sys.argv[1])
    valid_dataset = create_dataset(sys.argv[2])

    model = TestModel()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.MeanAbsoluteError(),
        metrics=[tf.keras.metrics.Accuracy()],
    )
    model.fit(x=train_dataset, validation_data=valid_dataset, validation_steps=None, validation_freq=1, epochs=1)

This code generates a slightly different error when resizing the ground truth images:

ValueError: Value [<tf.Tensor 'resize/Squeeze:0' shape=(270, 480, 1) dtype=float32>, <tf.Tensor 'resize_1/Squeeze:0' shape=(135, 240, 1) dtype=float32>] is not convertible to a tensor with dtype <dtype: 'float32'> and shape (2, None, None, 1).

but otherwise this code behaves similarly to what I described above.

Play with the if statement conditions in the scale_output_resolution function to see the different behaviours.

Run the code as python ./test.py ./train.tfrecord ./valid.tfrecord.

Here are some links to tfrecord files to test.
train.tfrecord
valid.tfrecord

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Jan 23, 2020
@ravikyram ravikyram added the comp:keras Keras related issues label Jan 23, 2020
@ravikyram
Contributor

@Bidski

I tried this on Colab with TF version 2.1.0-rc2 and I am seeing a different error message. Please find the gist here. Thanks!

@ravikyram ravikyram added the stat:awaiting response Status - Awaiting response from author label Jan 23, 2020
@Bidski
Author

Bidski commented Jan 23, 2020

@ravikyram that looks to be the same error message that I reported originally?

@ravikyram ravikyram added the type:bug Bug label Jan 23, 2020
@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Jan 23, 2020
@Selotkin

Selotkin commented Feb 25, 2020

@Bidski
I had a similar issue. I have two feature arrays as inputs to the network and an output vector.
After calling:
for f0, f1 in train_ds.take(1):
    print(f0.shape)
    print(f1.shape)
I get the output:
(16, 2, 1000, 3)
(16, 9)

However, model.fit raises the same error.

@lzmax888

I also have this issue when I call 'predict' of a model with multiple outputs.

Reproduction:
The following code works perfectly in TF 2.0, but raises 'ValueError: Error when checking model....' in TF 2.1.
https://colab.research.google.com/drive/1hMLd5-r82FrnFnBub-B-fVW78Px4KPX1

Can anyone fix it?
Thanks.

@Strateus

Strateus commented Apr 1, 2020

Same here: I extract losses from several layers, but the loss function is the same for all of them. Compilation succeeds, but .fit fails:

ValueError: Error when checking model target: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 11 array(s), for inputs [...layer names deleted...] but instead got the following list of 1 arrays: [<tf.Tensor 'args_1:0' shape=(None, None) dtype=float64>]

P.S. .predict might be a different issue; it works fine for me.

@Strateus

Strateus commented Apr 1, 2020

Here is a reproducible example:

import tensorflow as tf
import numpy as np

num_examples = 1000
input_shape = (10,)
loss_dim = 3
num_losses = 2

features = np.random.random(size=[num_examples] + list(input_shape))
labels = np.random.random(size=[num_examples, num_losses, loss_dim])

inputs = tf.keras.Input(shape=input_shape, name='features')
l1 = tf.keras.layers.Dense(loss_dim, activation='relu', name='l1')(inputs)
l2 = tf.keras.layers.Dense(loss_dim, activation='relu', name='l2')(l1)
model = tf.keras.Model(inputs=inputs, outputs=[l1, l2])

def loss(y, y_hat):
    return tf.abs(y - y_hat)

model.compile(optimizer='adam', loss=[loss, loss], loss_weights=[0.2, 1.])

model.fit(
    features, 
    labels,
    validation_split=0.2,
    epochs=10,
    batch_size=32
)

I also tried this shape:
labels = np.random.random(size=[num_losses, num_examples, loss_dim])

If I convert the labels into a list, it works:

label_list = [l for l in labels.reshape(num_losses, num_examples, loss_dim)]
model.fit(
    features, 
    label_list,
    validation_split=0.2,
    epochs=10,
    batch_size=32
)

The main problem for me, though, is how to put this into a Dataset, which does not seem to support anything besides tensors. I cannot use the solution above with a Dataset, unless I have missed some functionality.
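For the Dataset case, one option is that `tf.data.Dataset.from_tensor_slices` accepts nested tuples, so the per-output labels can be passed as a tuple in place of the Python list. The following is a minimal sketch under that assumption, reusing the layer names from the example above. The built-in `"mae"` loss stands in for the custom `loss` function, the sizes are reduced, and it has not been verified against TF 2.1 specifically.

```python
import numpy as np
import tensorflow as tf

num_examples, loss_dim = 64, 3
features = np.random.random((num_examples, 10)).astype("float32")
labels = np.random.random((num_examples, 2, loss_dim)).astype("float32")

inputs = tf.keras.Input(shape=(10,), name="features")
l1 = tf.keras.layers.Dense(loss_dim, activation="relu", name="l1")(inputs)
l2 = tf.keras.layers.Dense(loss_dim, activation="relu", name="l2")(l1)
model = tf.keras.Model(inputs=inputs, outputs=[l1, l2])

# "mae" replaces the custom loss from the original example (assumption).
model.compile(optimizer="adam", loss=["mae", "mae"], loss_weights=[0.2, 1.0])

# Each element is (features, (labels_for_l1, labels_for_l2)):
# a tuple of targets, one per model output.
dataset = tf.data.Dataset.from_tensor_slices(
    (features, (labels[:, 0], labels[:, 1]))
).batch(32)

history = model.fit(dataset, epochs=1, verbose=0)
print(sorted(history.history))
```

Slicing `labels[:, 0]` and `labels[:, 1]` keeps each example's labels aligned with its features, which a plain `reshape` to `(num_losses, num_examples, loss_dim)` would not do for data laid out as `(num_examples, num_losses, loss_dim)`.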

@PrasadNR

Is keras-team/keras-preprocessing#295 somehow related to this?

@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label May 27, 2020
@jvishnuvardhan
Contributor

@Strateus I updated your code for multiple outputs. It works as expected. Please check the gist here. Thanks!

@Bidski Is this still an issue for you? Also, please check the Functional API guide for a detailed walkthrough of multiple-input and multiple-output models. Thanks!

@jvishnuvardhan jvishnuvardhan added stat:awaiting response Status - Awaiting response from author and removed stat:awaiting tensorflower Status - Awaiting response from tensorflower labels May 21, 2021
@google-ml-butler

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label May 28, 2021
@google-ml-butler

Closing as stale. Please reopen if you'd like to work on this further.

