Up to model inference (Android environment)
Lark6 committed Dec 12, 2024
1 parent f5ca719 commit 8686cc9
Showing 2 changed files with 82 additions and 6 deletions.
87 changes: 81 additions & 6 deletions packages/example/lib/vision_detector_views/face_detector_view.dart
@@ -3,6 +3,7 @@ import 'package:camera/camera.dart';
import 'package:flutter/material.dart';
import 'package:google_mlkit_face_detection/google_mlkit_face_detection.dart';
import 'package:image/image.dart' as img;
import 'package:image/image.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import 'detector_view.dart';
import 'painters/face_detector_painter.dart';
@@ -88,14 +89,26 @@ class _FaceDetectorViewState extends State<FaceDetectorView> {
      // Compute the square crop region
      final squareRect = _getSquareRect(faces, inputImage.metadata!.size);
      if (squareRect != null) {
        print("Updated squareRect: $squareRect");
        image = cropImage(image, squareRect);
        if (image != null) {
          // A prediction step can be added here
          // _predict(image);
          print(image);
          //print('Cropped Image width: ${image.width}, height: ${image.height}');

          final input = convertToFloat32List(await getCnnInput(image)).reshape([1, 224, 224, 3]);

          // print('input shape: ${input.shape}');
          // print('input dtype: ${input.runtimeType}'); // Check that this is TensorFlow's float32

          // print(_interpreter.getInputTensors()[0]);
          // print(_interpreter.getOutputTensors()[0]);

          // Run inference and get the output
          final output = List.filled(1 * 1, 0.0).reshape([1, 1]); // Output buffer of shape [1, 1]
          // print('output shape: ${output.shape}');
          _interpreter.run(input, output);

          setState(() {
            _text = 'Inference Result: ${output[0]}';
            print(_text);
          });
        }
      } else {
        print("No square rect detected.");
      }
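A side note on the block above: after `reshape([1, 1])`, `output` is a nested list, so `${output[0]}` prints a one-element list rather than a bare number. A minimal sketch of pulling the scalar out, purely illustrative and assuming the model emits a single value per image:

// Sketch only: output was built as List.filled(1 * 1, 0.0).reshape([1, 1]),
// so after _interpreter.run(input, output) the scalar sits at output[0][0].
final double score = (output[0][0] as num).toDouble();
_text = 'Inference Result: ${score.toStringAsFixed(3)}';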
@@ -219,5 +232,67 @@ class _FaceDetectorViewState extends State<FaceDetectorView> {
      return null;
    }
  }
  List<List<List<List<double>>>> _preprocessImage(img.Image image) {
    // Step 1: Resize the image to 224x224
    final img.Image resizedImage = img.copyResize(image, width: 224, height: 224);

    // Step 2: Scale pixel values to [0, 1], then map them to [-1, 1]
    final List<List<List<List<double>>>> input = List.generate(
      1, // Batch size of 1
      (_) => List.generate(
        224, // Height
        (y) => List.generate(
          224, // Width
          (x) {
            // Extract the pixel (RGBA)
            final pixel = resizedImage.getPixelSafe(x, y);

            // Access the RGB components and scale to [0, 1]
            final r = pixel.r / 255.0;
            final g = pixel.g / 255.0;
            final b = pixel.b / 255.0;

            // Map [0, 1] to [-1, 1]
            return [r * 2 - 1, g * 2 - 1, b * 2 - 1];
          },
        ),
      ),
    );

    return input;
  }

  Float32List convertToFloat32List(List<List<List<double>>> input) {
    // Flatten the [224][224][3] nested lists into a single Float32List.
    final flattened = input.expand((row) => row.expand((col) => col)).toList();
    return Float32List.fromList(flattened);
  }

  Future<List<List<List<double>>>> getCnnInput(img.Image image) async {
    //print('cnn input');
    // Resize the image to [224, 224].
    img.Image resizedImage = img.copyResize(image, width: 224, height: 224);

    // Initialize the array that will hold the result.
    List<List<List<double>>> result = List.generate(
        224, (_) => List.generate(224, (_) => List.filled(3, 0.0)));

    // Copy the pixel data into the array as doubles (raw 0-255 values;
    // divide by 255.0 here if the model expects inputs scaled to [0, 1]).
    for (int y = 0; y < resizedImage.height; y++) {
      for (int x = 0; x < resizedImage.width; x++) {
        Pixel pixel = resizedImage.getPixel(x, y);
        double r = pixel.r.toDouble();
        double g = pixel.g.toDouble();
        double b = pixel.b.toDouble();
        result[y][x][0] = r; // Red
        result[y][x][1] = g; // Green
        result[y][x][2] = b; // Blue
      }
    }
    //print(result.shape);
    return result;
  }

}
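The code above calls `_interpreter.run(...)`, but this diff does not show where `_interpreter` is created. Below is a minimal sketch of how it might be loaded with `tflite_flutter` from the asset registered in pubspec.yaml; `_loadModel` is a hypothetical helper (e.g. called from `initState`), and the sketch relies on the `tflite_flutter` import already present at the top of the file:

  // Sketch only: assumes the asset path 'assets/food_resnet50.tflite' declared
  // in pubspec.yaml and tflite_flutter's Interpreter.fromAsset constructor.
  late final Interpreter _interpreter;

  Future<void> _loadModel() async {
    // Recent tflite_flutter versions take the full asset path as declared in
    // pubspec.yaml; older versions resolved names relative to assets/.
    _interpreter = await Interpreter.fromAsset('assets/food_resnet50.tflite');

    // Optional sanity check: the inference code feeds a [1, 224, 224, 3] float
    // input and reads back a [1, 1] float output.
    print(_interpreter.getInputTensors());
    print(_interpreter.getOutputTensors());
  }

Calling `_interpreter.close()` in `dispose()` releases the native interpreter resources when the widget goes away.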
1 change: 1 addition & 0 deletions packages/example/pubspec.yaml
@@ -70,3 +70,4 @@ flutter:
  assets:
    - assets/ml/
    - assets/images/
    - assets/food_resnet50.tflite
