How to get class for given text input? #1

Open
GraphGrailAi opened this issue Jan 22, 2017 · 7 comments

Comments

@GraphGrailAi

Could you provide some example code showing how to get the predicted class for a given text input?

I was able to get all the code working with ./data/small_samples.json, but the output is only an accuracy percentage. I need the exact class name for every text.

@jiegzhan
Owner

  1. When you run train.py, labels.json is saved. labels.json is a list of all labels.

  2. When you run predict.py, take a look at line 63. If you print batch_predictions, it is a list of numbers, and each number is an index into labels.json.

For example, I printed batch_predictions: [6 6 6 4 6 4 3 6 4 6 3 4 1 2 3 2 3 2 4 0 4 4 4 3 4 6 4 4 1 4 0 6 2 4 4 6 3 3 1 3 4 4 3 4 3 6 3 6 6 6]
The first number in batch_predictions is 6, so the corresponding label is labels.json[6], which is mortgage.

Hope this will help you find the corresponding labels.
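
The lookup described above could be sketched roughly as follows. This is only an illustration: the helper name indices_to_labels is hypothetical, and every label except "mortgage" at index 6 is a placeholder, since the real list comes from whatever labels.json train.py wrote for your data.

```python
import json


def indices_to_labels(predictions, labels):
    """Map each predicted index to the label stored at that position."""
    return [labels[int(i)] for i in predictions]


# Placeholder label list standing in for the contents of labels.json.
# Only "mortgage" at index 6 comes from the example above; the other
# names are invented. In practice you would load the real file, e.g.
#     with open("labels.json") as f:   # path is an assumption
#         labels = json.load(f)
labels = json.loads(
    '["label_0", "label_1", "label_2", "label_3", '
    '"label_4", "label_5", "mortgage"]'
)

# First five values of the batch_predictions printed above.
batch_predictions = [6, 6, 6, 4, 6]
print(indices_to_labels(batch_predictions, labels))
```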

@GraphGrailAi
Author

Thanks for the answer. I had guessed as much myself, and my tests show that the list of predicted label indices is all_predictions (not batch_predictions). When I print batch_predictions, it returns an empty list [].

predict.py from line 63:

			for x_test_batch in batches:
				batch_predictions = sess.run(predictions, {input_x: x_test_batch, dropout_keep_prob: 1.0})
				all_predictions = np.concatenate([all_predictions, batch_predictions])

	if y_test is not None:
		y_test = np.argmax(y_test, axis=1)
		correct_predictions = sum(all_predictions == y_test)
		logging.critical('The batch_predictions is: {}'.format(batch_predictions))
		logging.critical('The all_predictions is: {}'.format(all_predictions))
		logging.critical('The y_test is: {}'.format(y_test)) # y_test is label list in labels.json
		logging.critical('The correct_predictions is: {}'.format(correct_predictions))
		logging.critical('The accuracy is: {}'.format(correct_predictions / float(len(y_test))))

output:

d:\Django\multi-class-text-classification-cnn>python predict.py ./trained_model_1485334811/ ./data/small_samples_my.json
CRITICAL:root:Loaded the trained model: d:\Django\multi-class-text-classification-cnn\trained_model_1485334811\checkpoints\model-300
INFO:root:The number of x_test: 5
INFO:root:The number of y_test: 5
CRITICAL:root:The batch_predictions is: []
CRITICAL:root:The all_predictions is: [ 10.  10.  10.   8.  10.]
CRITICAL:root:The y_test is: [10  6 10  8  9]
CRITICAL:root:The correct_predictions is: 3
CRITICAL:root:The accuracy is: 0.6

@jiegzhan
Owner

Actually, each batch produces its own batch_predictions list, which is appended to all_predictions.

Eventually, if you have 100 test examples, all_predictions will hold 100 numbers. Each number is an index into labels.json, so you can get the actual label from labels.json[index].
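
One detail worth noting when doing that lookup: the output above prints all_predictions as floats ([ 10.  10.  10.   8.  10.]). That is what np.concatenate produces if all_predictions starts out as an empty Python list, which is an assumption here because the excerpt does not show its initialization. A rough sketch of the accumulation and the integer cast needed before indexing (the label names are placeholders, not the real contents of labels.json):

```python
import numpy as np

# Assumption: all_predictions starts as an empty list before the loop.
all_predictions = []

# Two made-up batches that reproduce the five predictions printed above.
batches_of_predictions = [np.array([10, 10, 10]), np.array([8, 10])]

for batch_predictions in batches_of_predictions:
    # Same accumulation step as in the predict.py excerpt.
    all_predictions = np.concatenate([all_predictions, batch_predictions])

# Concatenating with an empty list promotes the result to floats.
print(all_predictions.dtype)  # float64

# Floats cannot index a list, so cast to int first.
label_indices = all_predictions.astype(int)

# Placeholder labels; the real ones come from labels.json.
labels = ["label_{}".format(i) for i in range(11)]
predicted_labels = [labels[i] for i in label_indices]
print(predicted_labels)
```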

@GraphGrailAi
Author

Thank you! I will create a separate issue for my other question.

@akki2825

Has anyone figured this out?
I need a score for each class the model predicts. For example, if the text belongs to a single class, I need the probability that the text belongs to that class.
Any help would keep me moving.

@vijaysaimutyala

vijaysaimutyala commented Jan 29, 2018

@akki2825 Were you able to find a solution for predicting the probability of the classified text?

@GraphGrailAi Did you mean that you got the accuracy for each predicted class, or the accuracy of the whole model?

@Chinguun8

Has anyone found a way to print the probability of each sentence prediction?
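
The thread does not contain a confirmed answer to this. One common approach, offered here only as a sketch, is to take the raw per-class scores the network computes before the argmax and pass them through a softmax. This assumes the trained graph exposes such a scores tensor that can be fetched with sess.run alongside predictions, which has not been verified against this repository. The scores array below is made up for illustration.

```python
import numpy as np


def softmax(scores):
    """Turn raw class scores into per-row probabilities that sum to 1."""
    scores = np.asarray(scores, dtype=float)
    # Subtract the row maximum for numerical stability.
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# Made-up raw scores for two sentences and three classes. In practice
# these would come from the model's score tensor (an assumption).
scores = np.array([[2.0, 1.0, 0.1],
                   [0.5, 0.5, 3.0]])

probs = softmax(scores)
predicted = probs.argmax(axis=1)   # index into labels.json
confidence = probs.max(axis=1)     # probability of the predicted class

for idx, p in zip(predicted, confidence):
    print("class index {} with probability {:.3f}".format(idx, p))
```

The argmax of the probabilities is the same index the existing predictions give, so the label lookup through labels.json stays unchanged.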
