MainActivity.java 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. package de.lernapparat.zebraify;
  2. import android.content.Context;
  3. import android.content.Intent;
  4. import android.graphics.Bitmap;
  5. import android.provider.MediaStore;
  6. import android.support.v7.app.AppCompatActivity;
  7. import android.os.Bundle;
  8. import android.util.Log;
  9. import android.view.View;
  10. import android.widget.ImageView;
  11. import android.widget.TextView;
  12. import org.pytorch.IValue;
  13. import org.pytorch.Module;
  14. import org.pytorch.Tensor;
  15. import org.pytorch.torchvision.TensorImageUtils;
  16. import java.io.File;
  17. import java.io.FileOutputStream;
  18. import java.io.IOException;
  19. import java.io.InputStream;
  20. import java.io.OutputStream;
  21. public class MainActivity extends AppCompatActivity {
  22. static final int REQUEST_IMAGE_CAPTURE = 1;
  23. private org.pytorch.Module model;
  24. @Override
  25. protected void onCreate(Bundle savedInstanceState) {
  26. super.onCreate(savedInstanceState);
  27. setContentView(R.layout.activity_main);
  28. TextView tv= (TextView) findViewById(R.id.headline);
  29. tv.setOnClickListener(new View.OnClickListener() {
  30. public void onClick(View v) {
  31. Intent takePictureIntent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
  32. // takePictureIntent.putExtra(android.provider.MediaStore.EXTRA_OUTPUT, android.provider.MediaStore.Images.Media.EXTERNAL_CONTENT_URI);
  33. if (takePictureIntent.resolveActivity(getPackageManager()) != null) {
  34. startActivityForResult(takePictureIntent, REQUEST_IMAGE_CAPTURE);
  35. }
  36. }
  37. });
  38. try {
  39. model = Module.load(assetFilePath(this, "traced_zebra_model.pt"));
  40. } catch (IOException e) {
  41. Log.e("Zebraify", "Error reading assets", e);
  42. finish();
  43. }
  44. }
  45. @Override
  46. protected void onActivityResult(int requestCode, int resultCode, Intent data) {
  47. if (requestCode == REQUEST_IMAGE_CAPTURE && resultCode == RESULT_OK) {
  48. // this gets called when the camera app got a picture
  49. Bitmap bitmap = (Bitmap) data.getExtras().get("data");
  50. final float[] means = {0.0f, 0.0f, 0.0f};
  51. final float[] stds = {1.0f, 1.0f, 1.0f};
  52. // preparing input tensor
  53. final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
  54. means, stds);
  55. // running the model
  56. final Tensor outputTensor = model.forward(IValue.from(inputTensor)).toTensor();
  57. Bitmap output_bitmap = tensorToBitmap(outputTensor, means, stds, Bitmap.Config.RGB_565);
  58. ImageView image_view = (ImageView) findViewById(R.id.imageView);
  59. image_view.setImageBitmap(output_bitmap);
  60. }
  61. }
  62. // This is intended to be the inverse of bitmapToFloat32Tensor
  63. static Bitmap tensorToBitmap(Tensor tensor, float[] normMeanRGB, float[] normStdRGB, Bitmap.Config bc) {
  64. final float[] outputArray = tensor.getDataAsFloatArray();
  65. final long[] shape = tensor.shape();
  66. int width = (int) shape[shape.length - 1];
  67. int height = (int) shape[shape.length - 2];
  68. Bitmap output_bitmap = Bitmap.createBitmap(width, height, bc);
  69. int numPixels = width * height;
  70. int[] pixels = new int[numPixels];
  71. for (int i = 0; i < numPixels; i++) {
  72. pixels[i] = ((int) ((outputArray[0 * numPixels + i] * normStdRGB[0] + normMeanRGB[0]) * 255 + 0.49999) << 16)
  73. + ((int) ((outputArray[1 * numPixels + i] * normStdRGB[1] + normMeanRGB[1]) * 255 + 0.49999) << 8)
  74. + ((int) ((outputArray[2 * numPixels + i] * normStdRGB[2] + normMeanRGB[2]) * 255 + 0.49999));
  75. }
  76. output_bitmap.setPixels(pixels, 0, width, 0, 0, width, height);
  77. return output_bitmap;
  78. }
  79. /**
  80. * Taken from PyTorch's HelloWorld Android app.
  81. *
  82. * Copies specified asset to the file in /files app directory and returns this file absolute path.
  83. *
  84. * @return absolute file path
  85. */
  86. public static String assetFilePath(Context context, String assetName) throws IOException {
  87. File file = new File(context.getFilesDir(), assetName);
  88. if (false && file.exists() && file.length() > 0) {
  89. return file.getAbsolutePath();
  90. }
  91. try (InputStream is = context.getAssets().open(assetName)) {
  92. try (OutputStream os = new FileOutputStream(file, false)) {
  93. byte[] buffer = new byte[4 * 1024];
  94. int read;
  95. while ((read = is.read(buffer)) != -1) {
  96. os.write(buffer, 0, read);
  97. }
  98. os.flush();
  99. }
  100. return file.getAbsolutePath();
  101. }
  102. }
  103. }