From ebaa4266f35fac2d4456950876922348a28d2da2 Mon Sep 17 00:00:00 2001 From: Esun Kim Date: Wed, 6 May 2026 09:48:47 -0700 Subject: [PATCH 1/5] No tensorflow 1 --- python/python_requirements.in | 1 + python/python_requirements_3_10.txt | 97 ++++++++++++++++++- python/python_requirements_3_11.txt | 47 ++++++++- python/python_requirements_3_12.txt | 35 ++++++- python/python_requirements_3_13.txt | 33 ++++++- tensorflow/lite/micro/tools/BUILD | 2 +- .../micro/tools/layer_by_layer_debugger.py | 42 ++++++-- tensorflow/lite/python/BUILD | 1 - tensorflow/lite/python/schema_util.py | 6 -- tensorflow/lite/tools/BUILD | 4 - tensorflow/lite/tools/flatbuffer_utils.py | 8 +- .../lite/tools/flatbuffer_utils_test.py | 22 ++--- tensorflow/lite/tools/visualize_test.py | 10 +- 13 files changed, 262 insertions(+), 46 deletions(-) diff --git a/python/python_requirements.in b/python/python_requirements.in index 11c0f0e6063..a5abe23dcae 100644 --- a/python/python_requirements.in +++ b/python/python_requirements.in @@ -27,6 +27,7 @@ # it is run. absl_py +ai-edge-litert bitarray mako numpy diff --git a/python/python_requirements_3_10.txt b/python/python_requirements_3_10.txt index e007d68afe5..2c25821d468 100644 --- a/python/python_requirements_3_10.txt +++ b/python/python_requirements_3_10.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.13 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # bazel run //python:requirements_3_10.update @@ -12,10 +12,36 @@ absl-py==2.3.1 \ # keras # tensorboard # tensorflow +ai-edge-litert==2.1.4 \ + --hash=sha256:0152b9e9712995931cb7ccb54909481e6c3d20e8e7b8f9853972ffc6534b559c \ + --hash=sha256:1bed8ccd8dfc4f2ec388bc79cb90920c261a244995ce378d396fb8967b7db640 \ + --hash=sha256:1dc178c0a9e59df865fdfb536835c15a56224e78047006ad3c68e855dbbc12d8 \ + --hash=sha256:25aafb358229d9c87772b8448488ff373bbc42de0c9e1e62ed22202bdf6007af \ + --hash=sha256:2772aa69e7f1934cfc5bda3ee8bcd52a54d50cc3dfb853c7591da08d92f16f09 \ + --hash=sha256:4e3ccc847dbca7acdf8178443729b6d9d3fcb4a5f0dd29f6a834e68fcab4b04a \ + --hash=sha256:58fbfcaf04475e0f6a128ce1f52dffa5660d88d5f92788e3f40b5bba1fb02eab \ + --hash=sha256:72d14997d3ae976bf325c9831f1a7bf42cddfe575d8351ff3c516a294aa5d810 \ + --hash=sha256:95b5eb86874a2397c78be5737a168b425556467b2a58aab3f3fbb0ff7678a665 \ + --hash=sha256:9f42035b35c3de5062210098a09caa18428c8b1b1c8bd4ffc3206e6581767d1a \ + --hash=sha256:a4080394830873db3d9c3801801c65147dcd59eec26389c849477dae55d3d3cb \ + --hash=sha256:b5f50d754c0b2eb4c40074422f89503e6b85797a83356a2c7677ff44b6281f72 \ + --hash=sha256:c280d21f1111feb3219283417c1fb5f5a4bbf9e16a214fd96d6220fbc62dec95 \ + --hash=sha256:d9ccf39b6603233e92973d591065dd7a62e875c8b7a12d2c0558f8f66bd4f434 \ + --hash=sha256:f1026dc8c4249defc05d618feb1ee4827561d761cdf9838c9edffaca272bf0d5 \ + --hash=sha256:f2f07d27211a6b64ed374733bcbdd78dcf0cdc70de5c5f7d0ce7f198d4890a23 + # via -r python/python_requirements.in astunparse==1.6.3 \ --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 # via tensorflow +backports-strenum==1.3.1 \ + --hash=sha256:77c52407342898497714f0596e86188bb7084f89063226f4ba66863482f42414 \ + --hash=sha256:cdcfe36dc897e2615dc793b7d3097f54d359918fc448754a517e6f23044ccf83 + # via ai-edge-litert +backports-tarfile==1.2.0 \ + --hash=sha256:77e284d754527b01fb1e6fa8a1afe577858ebe4e9dad8919e34c862cb399bc34 \ + --hash=sha256:d75e02c268746e1b8144c278978b6e98e85de6ad16f8e4b0844a154557eca991 + # via jaraco-context bitarray==3.8.0 \ --hash=sha256:004d518fa410e6da43386d20e07b576a41eb417ac67abf9f30fa75e125697199 \ --hash=sha256:014df8a9430276862392ac5d471697de042367996c49f32d0008585d2c60755a \ @@ -393,7 +419,9 @@ docutils==0.22.4 \ # via readme-renderer flatbuffers==25.12.19 \ --hash=sha256:7634f50c427838bb021c2d66a3d1168e9d199b0607e6329399f04846d42e20b4 - # via tensorflow + # via + # ai-edge-litert + # tensorflow gast==0.7.0 \ --hash=sha256:0bb14cd1b806722e91ddbab6fb86bba148c22b40e7ff11e248974e04c8adfdae \ --hash=sha256:99cbf1365633a74099f69c59bd650476b96baa5ef196fec88032b00b31ba36f7 @@ -520,6 +548,10 @@ idna==3.11 \ --hash=sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea \ --hash=sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902 # via requests +importlib-metadata==9.0.0 \ + --hash=sha256:2d21d1cc5a017bd0559e36150c21c830ab1dc304dedd1b7ea85d20f45ef3edd7 \ + --hash=sha256:a4f57ab599e6a2e3016d7595cfd72eb4661a5106e787a95bcc90c7105b831efc + # via keyring jaraco-classes==3.4.0 \ --hash=sha256:47a024b51d0239c0dd8c8540c6c7f484be3b8fcf0b2d85c13825780d3b3f3acd \ --hash=sha256:f662826b6bed8cace05e7ff873ce0f9283b5c924470fe664fff1c2f00f581790 @@ -806,6 +838,7 @@ numpy==2.2.6 \ --hash=sha256:fee4236c876c4e8369388054d02d0e9bb84821feb1a64dd59e137e6511a551f8 # via # -r python/python_requirements.in + # ai-edge-litert # h5py # keras # ml-dtypes @@ -1041,6 +1074,7 @@ protobuf==6.33.2 \ --hash=sha256:fc2a0e8b05b180e5fc0dd1559fe8ebdae21a27e81ac77728fb6c42b12c7419b4 # via # -r python/python_requirements.in + # ai-edge-litert # tensorboard # tensorflow pycparser==2.23 \ @@ -1200,6 +1234,59 @@ termcolor==3.3.0 \ --hash=sha256:348871ca648ec6a9a983a13ab626c0acce02f515b9e1983332b17af7979521c5 \ --hash=sha256:cf642efadaf0a8ebbbf4bc7a31cec2f9b5f21a9f726f4ccbb08192c9c26f43a5 # via tensorflow +tomli==2.4.1 \ + --hash=sha256:01f520d4f53ef97964a240a035ec2a869fe1a37dde002b57ebc4417a27ccd853 \ + --hash=sha256:0d85819802132122da43cb86656f8d1f8c6587d54ae7dcaf30e90533028b49fe \ + --hash=sha256:136443dbd7e1dee43c68ac2694fde36b2849865fa258d39bf822c10e8068eac5 \ + --hash=sha256:1d8591993e228b0c930c4bb0db464bdad97b3289fb981255d6c9a41aedc84b2d \ + --hash=sha256:2190f2e9dd7508d2a90ded5ed369255980a1bcdd58e52f7fe24b8162bf9fedbd \ + --hash=sha256:2c1c351919aca02858f740c6d33adea0c5deea37f9ecca1cc1ef9e884a619d26 \ + --hash=sha256:36d2bd2ad5fb9eaddba5226aa02c8ec3fa4f192631e347b3ed28186d43be6b54 \ + --hash=sha256:3d48a93ee1c9b79c04bb38772ee1b64dcf18ff43085896ea460ca8dec96f35f6 \ + --hash=sha256:47149d5bd38761ac8be13a84864bf0b7b70bc051806bc3669ab1cbc56216b23c \ + --hash=sha256:4ab97e64ccda8756376892c53a72bd1f964e519c77236368527f758fbc36a53a \ + --hash=sha256:4b605484e43cdc43f0954ddae319fb75f04cc10dd80d830540060ee7cd0243cd \ + --hash=sha256:504aa796fe0569bb43171066009ead363de03675276d2d121ac1a4572397870f \ + --hash=sha256:51529d40e3ca50046d7606fa99ce3956a617f9b36380da3b7f0dd3dd28e68cb5 \ + --hash=sha256:52c8ef851d9a240f11a88c003eacb03c31fc1c9c4ec64a99a0f922b93874fda9 \ + --hash=sha256:559db847dc486944896521f68d8190be1c9e719fced785720d2216fe7022b662 \ + --hash=sha256:5a881ab208c0baf688221f8cecc5401bd291d67e38a1ac884d6736cbcd8247e9 \ + --hash=sha256:5cb41aa38891e073ee49d55fbc7839cfdb2bc0e600add13874d048c94aadddd1 \ + --hash=sha256:5e262d41726bc187e69af7825504c933b6794dc3fbd5945e41a79bb14c31f585 \ + --hash=sha256:5ee18d9ebdb417e384b58fe414e8d6af9f4e7a0ae761519fb50f721de398dd4e \ + --hash=sha256:7008df2e7655c495dd12d2a4ad038ff878d4ca4b81fccaf82b714e07eae4402c \ + --hash=sha256:734e20b57ba95624ecf1841e72b53f6e186355e216e5412de414e3c51e5e3c41 \ + --hash=sha256:7c7e1a961a0b2f2472c1ac5b69affa0ae1132c39adcb67aba98568702b9cc23f \ + --hash=sha256:7f86fd587c4ed9dd76f318225e7d9b29cfc5a9d43de44e5754db8d1128487085 \ + --hash=sha256:7f94b27a62cfad8496c8d2513e1a222dd446f095fca8987fceef261225538a15 \ + --hash=sha256:88dceee75c2c63af144e456745e10101eb67361050196b0b6af5d717254dddf7 \ + --hash=sha256:8a650c2dbafa08d42e51ba0b62740dae4ecb9338eefa093aa5c78ceb546fcd5c \ + --hash=sha256:8d65a2fbf9d2f8352685bc1364177ee3923d6baf5e7f43ea4959d7d8bc326a36 \ + --hash=sha256:96481a5786729fd470164b47cdb3e0e58062a496f455ee41b4403be77cb5a076 \ + --hash=sha256:a120733b01c45e9a0c34aeef92bf0cf1d56cfe81ed9d47d562f9ed591a9828ac \ + --hash=sha256:b1d22e6e9387bf4739fbe23bfa80e93f6b0373a7f1b96c6227c32bef95a4d7a8 \ + --hash=sha256:b8c198f8c1805dc42708689ed6864951fd2494f924149d3e4bce7710f8eb5232 \ + --hash=sha256:c2541745709bad0264b7d4705ad453b76ccd191e64aa6f0fc66b69a293a45ece \ + --hash=sha256:c742f741d58a28940ce01d58f0ab2ea3ced8b12402f162f4d534dfe18ba1cd6a \ + --hash=sha256:c7f2c7f2b9ca6bdeef8f0fa897f8e05085923eb091721675170254cbc5b02897 \ + --hash=sha256:d312ef37c91508b0ab2cee7da26ec0b3ed2f03ce12bd87a588d771ae15dcf82d \ + --hash=sha256:d4d8fe59808a54658fcc0160ecfb1b30f9089906c50b23bcb4c69eddc19ec2b4 \ + --hash=sha256:da25dc3563bff5965356133435b757a795a17b17d01dbc0f42fb32447ddfd917 \ + --hash=sha256:eab21f45c7f66c13f2a9e0e1535309cee140182a9cdae1e041d02e47291e8396 \ + --hash=sha256:eb0dc4e38e6a1fd579e5d50369aa2e10acfc9cace504579b2faabb478e76941a \ + --hash=sha256:ec9bfaf3ad2df51ace80688143a6a4ebc09a248f6ff781a9945e51937008fcbc \ + --hash=sha256:ede3e6487c5ef5d28634ba3f31f989030ad6af71edfb0055cbbd14189ff240ba \ + --hash=sha256:f3c6818a1a86dd6dca7ddcaaf76947d5ba31aecc28cb1b67009a5877c9a64f3f \ + --hash=sha256:f758f1b9299d059cc3f6546ae2af89670cb1c4d48ea29c3cacc4fe7de3058257 \ + --hash=sha256:f8f0fc26ec2cc2b965b7a3b87cd19c5c6b8c5e5f436b984e85f486d652285c30 \ + --hash=sha256:fd0409a3653af6c147209d267a0e4243f0ae46b011aa978b1080359fddc9b6cf \ + --hash=sha256:ff18e6a727ee0ab0388507b89d1bc6a22b138d1e2fa56d1ad494586d61d2eae9 \ + --hash=sha256:ff2983983d34813c1aeb0fa89091e76c3a22889ee83ab27c5eeb45100560c049 + # via yapf +tqdm==4.67.3 \ + --hash=sha256:7d825f03f89244ef73f1d4ce193cb1774a8179fd96f31d7e1dcde62092b960bb \ + --hash=sha256:ee1e4c0e59148062281c49d80b25b67771a127c85fc9676d3be5f243206826bf + # via ai-edge-litert twine==6.2.0 \ --hash=sha256:418ebf08ccda9a8caaebe414433b0ba5e25eb5e4a927667122fbe8f829f985d8 \ --hash=sha256:e5ed0d2fd70c9959770dce51c8f39c8945c574e18173a7b81802dab51b4b75cf @@ -1208,6 +1295,8 @@ typing-extensions==4.15.0 \ --hash=sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466 \ --hash=sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548 # via + # ai-edge-litert + # cryptography # grpcio # optree # tensorflow @@ -1338,6 +1427,10 @@ yapf==0.43.0 \ --hash=sha256:00d3aa24bfedff9420b2e0d5d9f5ab6d9d4268e72afbf59bb3fa542781d5218e \ --hash=sha256:224faffbc39c428cb095818cf6ef5511fdab6f7430a10783fdfb292ccf2852ca # via -r python/python_requirements.in +zipp==3.23.1 \ + --hash=sha256:0b3596c50a5c700c9cb40ba8d86d9f2cc4807e9bedb06bcdf7fac85633e444dc \ + --hash=sha256:32120e378d32cd9714ad503c1d024619063ec28aad2248dc6672ad13edfa5110 + # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: setuptools==80.9.0 \ diff --git a/python/python_requirements_3_11.txt b/python/python_requirements_3_11.txt index 53e62824e4c..fbbfcafb161 100644 --- a/python/python_requirements_3_11.txt +++ b/python/python_requirements_3_11.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.13 +# This file is autogenerated by pip-compile with Python 3.11 # by the following command: # # bazel run //python:requirements_3_11.update @@ -12,10 +12,36 @@ absl-py==2.3.1 \ # keras # tensorboard # tensorflow +ai-edge-litert==2.1.4 \ + --hash=sha256:0152b9e9712995931cb7ccb54909481e6c3d20e8e7b8f9853972ffc6534b559c \ + --hash=sha256:1bed8ccd8dfc4f2ec388bc79cb90920c261a244995ce378d396fb8967b7db640 \ + --hash=sha256:1dc178c0a9e59df865fdfb536835c15a56224e78047006ad3c68e855dbbc12d8 \ + --hash=sha256:25aafb358229d9c87772b8448488ff373bbc42de0c9e1e62ed22202bdf6007af \ + --hash=sha256:2772aa69e7f1934cfc5bda3ee8bcd52a54d50cc3dfb853c7591da08d92f16f09 \ + --hash=sha256:4e3ccc847dbca7acdf8178443729b6d9d3fcb4a5f0dd29f6a834e68fcab4b04a \ + --hash=sha256:58fbfcaf04475e0f6a128ce1f52dffa5660d88d5f92788e3f40b5bba1fb02eab \ + --hash=sha256:72d14997d3ae976bf325c9831f1a7bf42cddfe575d8351ff3c516a294aa5d810 \ + --hash=sha256:95b5eb86874a2397c78be5737a168b425556467b2a58aab3f3fbb0ff7678a665 \ + --hash=sha256:9f42035b35c3de5062210098a09caa18428c8b1b1c8bd4ffc3206e6581767d1a \ + --hash=sha256:a4080394830873db3d9c3801801c65147dcd59eec26389c849477dae55d3d3cb \ + --hash=sha256:b5f50d754c0b2eb4c40074422f89503e6b85797a83356a2c7677ff44b6281f72 \ + --hash=sha256:c280d21f1111feb3219283417c1fb5f5a4bbf9e16a214fd96d6220fbc62dec95 \ + --hash=sha256:d9ccf39b6603233e92973d591065dd7a62e875c8b7a12d2c0558f8f66bd4f434 \ + --hash=sha256:f1026dc8c4249defc05d618feb1ee4827561d761cdf9838c9edffaca272bf0d5 \ + --hash=sha256:f2f07d27211a6b64ed374733bcbdd78dcf0cdc70de5c5f7d0ce7f198d4890a23 + # via -r python/python_requirements.in astunparse==1.6.3 \ --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 # via tensorflow +backports-strenum==1.2.8 \ + --hash=sha256:4dd47365fd427ac8028aeb1ad3628ea38e67c4d0336ceebd5c0f113e0c487ce9 \ + --hash=sha256:fc297cb26971f7d5e15a478a06a78575197f81daea47975771b1aae996dcccf4 + # via ai-edge-litert +backports-tarfile==1.2.0 \ + --hash=sha256:77e284d754527b01fb1e6fa8a1afe577858ebe4e9dad8919e34c862cb399bc34 \ + --hash=sha256:d75e02c268746e1b8144c278978b6e98e85de6ad16f8e4b0844a154557eca991 + # via jaraco-context bitarray==3.8.0 \ --hash=sha256:004d518fa410e6da43386d20e07b576a41eb417ac67abf9f30fa75e125697199 \ --hash=sha256:014df8a9430276862392ac5d471697de042367996c49f32d0008585d2c60755a \ @@ -393,7 +419,9 @@ docutils==0.22.4 \ # via readme-renderer flatbuffers==25.12.19 \ --hash=sha256:7634f50c427838bb021c2d66a3d1168e9d199b0607e6329399f04846d42e20b4 - # via tensorflow + # via + # ai-edge-litert + # tensorflow gast==0.7.0 \ --hash=sha256:0bb14cd1b806722e91ddbab6fb86bba148c22b40e7ff11e248974e04c8adfdae \ --hash=sha256:99cbf1365633a74099f69c59bd650476b96baa5ef196fec88032b00b31ba36f7 @@ -520,6 +548,10 @@ idna==3.11 \ --hash=sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea \ --hash=sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902 # via requests +importlib-metadata==9.0.0 \ + --hash=sha256:2d21d1cc5a017bd0559e36150c21c830ab1dc304dedd1b7ea85d20f45ef3edd7 \ + --hash=sha256:a4f57ab599e6a2e3016d7595cfd72eb4661a5106e787a95bcc90c7105b831efc + # via keyring jaraco-classes==3.4.0 \ --hash=sha256:47a024b51d0239c0dd8c8540c6c7f484be3b8fcf0b2d85c13825780d3b3f3acd \ --hash=sha256:f662826b6bed8cace05e7ff873ce0f9283b5c924470fe664fff1c2f00f581790 @@ -823,6 +855,7 @@ numpy==2.4.0 \ --hash=sha256:f935c4493eda9069851058fa0d9e39dbf6286be690066509305e52912714dbb2 # via # -r python/python_requirements.in + # ai-edge-litert # h5py # keras # ml-dtypes @@ -1058,6 +1091,7 @@ protobuf==6.33.2 \ --hash=sha256:fc2a0e8b05b180e5fc0dd1559fe8ebdae21a27e81ac77728fb6c42b12c7419b4 # via # -r python/python_requirements.in + # ai-edge-litert # tensorboard # tensorflow pycparser==2.23 \ @@ -1217,6 +1251,10 @@ termcolor==3.3.0 \ --hash=sha256:348871ca648ec6a9a983a13ab626c0acce02f515b9e1983332b17af7979521c5 \ --hash=sha256:cf642efadaf0a8ebbbf4bc7a31cec2f9b5f21a9f726f4ccbb08192c9c26f43a5 # via tensorflow +tqdm==4.67.3 \ + --hash=sha256:7d825f03f89244ef73f1d4ce193cb1774a8179fd96f31d7e1dcde62092b960bb \ + --hash=sha256:ee1e4c0e59148062281c49d80b25b67771a127c85fc9676d3be5f243206826bf + # via ai-edge-litert twine==6.2.0 \ --hash=sha256:418ebf08ccda9a8caaebe414433b0ba5e25eb5e4a927667122fbe8f829f985d8 \ --hash=sha256:e5ed0d2fd70c9959770dce51c8f39c8945c574e18173a7b81802dab51b4b75cf @@ -1225,6 +1263,7 @@ typing-extensions==4.15.0 \ --hash=sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466 \ --hash=sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548 # via + # ai-edge-litert # grpcio # optree # tensorflow @@ -1355,6 +1394,10 @@ yapf==0.43.0 \ --hash=sha256:00d3aa24bfedff9420b2e0d5d9f5ab6d9d4268e72afbf59bb3fa542781d5218e \ --hash=sha256:224faffbc39c428cb095818cf6ef5511fdab6f7430a10783fdfb292ccf2852ca # via -r python/python_requirements.in +zipp==3.23.1 \ + --hash=sha256:0b3596c50a5c700c9cb40ba8d86d9f2cc4807e9bedb06bcdf7fac85633e444dc \ + --hash=sha256:32120e378d32cd9714ad503c1d024619063ec28aad2248dc6672ad13edfa5110 + # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: setuptools==80.9.0 \ diff --git a/python/python_requirements_3_12.txt b/python/python_requirements_3_12.txt index a1375828991..29738da227f 100644 --- a/python/python_requirements_3_12.txt +++ b/python/python_requirements_3_12.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.13 +# This file is autogenerated by pip-compile with Python 3.12 # by the following command: # # bazel run //python:requirements_3_12.update @@ -12,10 +12,32 @@ absl-py==2.3.1 \ # keras # tensorboard # tensorflow +ai-edge-litert==2.1.4 \ + --hash=sha256:0152b9e9712995931cb7ccb54909481e6c3d20e8e7b8f9853972ffc6534b559c \ + --hash=sha256:1bed8ccd8dfc4f2ec388bc79cb90920c261a244995ce378d396fb8967b7db640 \ + --hash=sha256:1dc178c0a9e59df865fdfb536835c15a56224e78047006ad3c68e855dbbc12d8 \ + --hash=sha256:25aafb358229d9c87772b8448488ff373bbc42de0c9e1e62ed22202bdf6007af \ + --hash=sha256:2772aa69e7f1934cfc5bda3ee8bcd52a54d50cc3dfb853c7591da08d92f16f09 \ + --hash=sha256:4e3ccc847dbca7acdf8178443729b6d9d3fcb4a5f0dd29f6a834e68fcab4b04a \ + --hash=sha256:58fbfcaf04475e0f6a128ce1f52dffa5660d88d5f92788e3f40b5bba1fb02eab \ + --hash=sha256:72d14997d3ae976bf325c9831f1a7bf42cddfe575d8351ff3c516a294aa5d810 \ + --hash=sha256:95b5eb86874a2397c78be5737a168b425556467b2a58aab3f3fbb0ff7678a665 \ + --hash=sha256:9f42035b35c3de5062210098a09caa18428c8b1b1c8bd4ffc3206e6581767d1a \ + --hash=sha256:a4080394830873db3d9c3801801c65147dcd59eec26389c849477dae55d3d3cb \ + --hash=sha256:b5f50d754c0b2eb4c40074422f89503e6b85797a83356a2c7677ff44b6281f72 \ + --hash=sha256:c280d21f1111feb3219283417c1fb5f5a4bbf9e16a214fd96d6220fbc62dec95 \ + --hash=sha256:d9ccf39b6603233e92973d591065dd7a62e875c8b7a12d2c0558f8f66bd4f434 \ + --hash=sha256:f1026dc8c4249defc05d618feb1ee4827561d761cdf9838c9edffaca272bf0d5 \ + --hash=sha256:f2f07d27211a6b64ed374733bcbdd78dcf0cdc70de5c5f7d0ce7f198d4890a23 + # via -r python/python_requirements.in astunparse==1.6.3 \ --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 # via tensorflow +backports-strenum==1.2.8 \ + --hash=sha256:4dd47365fd427ac8028aeb1ad3628ea38e67c4d0336ceebd5c0f113e0c487ce9 \ + --hash=sha256:fc297cb26971f7d5e15a478a06a78575197f81daea47975771b1aae996dcccf4 + # via ai-edge-litert bitarray==3.8.0 \ --hash=sha256:004d518fa410e6da43386d20e07b576a41eb417ac67abf9f30fa75e125697199 \ --hash=sha256:014df8a9430276862392ac5d471697de042367996c49f32d0008585d2c60755a \ @@ -393,7 +415,9 @@ docutils==0.22.4 \ # via readme-renderer flatbuffers==25.12.19 \ --hash=sha256:7634f50c427838bb021c2d66a3d1168e9d199b0607e6329399f04846d42e20b4 - # via tensorflow + # via + # ai-edge-litert + # tensorflow gast==0.7.0 \ --hash=sha256:0bb14cd1b806722e91ddbab6fb86bba148c22b40e7ff11e248974e04c8adfdae \ --hash=sha256:99cbf1365633a74099f69c59bd650476b96baa5ef196fec88032b00b31ba36f7 @@ -823,6 +847,7 @@ numpy==2.4.0 \ --hash=sha256:f935c4493eda9069851058fa0d9e39dbf6286be690066509305e52912714dbb2 # via # -r python/python_requirements.in + # ai-edge-litert # h5py # keras # ml-dtypes @@ -1058,6 +1083,7 @@ protobuf==6.33.2 \ --hash=sha256:fc2a0e8b05b180e5fc0dd1559fe8ebdae21a27e81ac77728fb6c42b12c7419b4 # via # -r python/python_requirements.in + # ai-edge-litert # tensorboard # tensorflow pycparser==2.23 \ @@ -1217,6 +1243,10 @@ termcolor==3.3.0 \ --hash=sha256:348871ca648ec6a9a983a13ab626c0acce02f515b9e1983332b17af7979521c5 \ --hash=sha256:cf642efadaf0a8ebbbf4bc7a31cec2f9b5f21a9f726f4ccbb08192c9c26f43a5 # via tensorflow +tqdm==4.67.3 \ + --hash=sha256:7d825f03f89244ef73f1d4ce193cb1774a8179fd96f31d7e1dcde62092b960bb \ + --hash=sha256:ee1e4c0e59148062281c49d80b25b67771a127c85fc9676d3be5f243206826bf + # via ai-edge-litert twine==6.2.0 \ --hash=sha256:418ebf08ccda9a8caaebe414433b0ba5e25eb5e4a927667122fbe8f829f985d8 \ --hash=sha256:e5ed0d2fd70c9959770dce51c8f39c8945c574e18173a7b81802dab51b4b75cf @@ -1225,6 +1255,7 @@ typing-extensions==4.15.0 \ --hash=sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466 \ --hash=sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548 # via + # ai-edge-litert # grpcio # optree # tensorflow diff --git a/python/python_requirements_3_13.txt b/python/python_requirements_3_13.txt index fef94f18c2b..5e8b4922b4f 100644 --- a/python/python_requirements_3_13.txt +++ b/python/python_requirements_3_13.txt @@ -12,10 +12,32 @@ absl-py==2.3.1 \ # keras # tensorboard # tensorflow +ai-edge-litert==2.1.4 \ + --hash=sha256:0152b9e9712995931cb7ccb54909481e6c3d20e8e7b8f9853972ffc6534b559c \ + --hash=sha256:1bed8ccd8dfc4f2ec388bc79cb90920c261a244995ce378d396fb8967b7db640 \ + --hash=sha256:1dc178c0a9e59df865fdfb536835c15a56224e78047006ad3c68e855dbbc12d8 \ + --hash=sha256:25aafb358229d9c87772b8448488ff373bbc42de0c9e1e62ed22202bdf6007af \ + --hash=sha256:2772aa69e7f1934cfc5bda3ee8bcd52a54d50cc3dfb853c7591da08d92f16f09 \ + --hash=sha256:4e3ccc847dbca7acdf8178443729b6d9d3fcb4a5f0dd29f6a834e68fcab4b04a \ + --hash=sha256:58fbfcaf04475e0f6a128ce1f52dffa5660d88d5f92788e3f40b5bba1fb02eab \ + --hash=sha256:72d14997d3ae976bf325c9831f1a7bf42cddfe575d8351ff3c516a294aa5d810 \ + --hash=sha256:95b5eb86874a2397c78be5737a168b425556467b2a58aab3f3fbb0ff7678a665 \ + --hash=sha256:9f42035b35c3de5062210098a09caa18428c8b1b1c8bd4ffc3206e6581767d1a \ + --hash=sha256:a4080394830873db3d9c3801801c65147dcd59eec26389c849477dae55d3d3cb \ + --hash=sha256:b5f50d754c0b2eb4c40074422f89503e6b85797a83356a2c7677ff44b6281f72 \ + --hash=sha256:c280d21f1111feb3219283417c1fb5f5a4bbf9e16a214fd96d6220fbc62dec95 \ + --hash=sha256:d9ccf39b6603233e92973d591065dd7a62e875c8b7a12d2c0558f8f66bd4f434 \ + --hash=sha256:f1026dc8c4249defc05d618feb1ee4827561d761cdf9838c9edffaca272bf0d5 \ + --hash=sha256:f2f07d27211a6b64ed374733bcbdd78dcf0cdc70de5c5f7d0ce7f198d4890a23 + # via -r python/python_requirements.in astunparse==1.6.3 \ --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 # via tensorflow +backports-strenum==1.2.8 \ + --hash=sha256:4dd47365fd427ac8028aeb1ad3628ea38e67c4d0336ceebd5c0f113e0c487ce9 \ + --hash=sha256:fc297cb26971f7d5e15a478a06a78575197f81daea47975771b1aae996dcccf4 + # via ai-edge-litert bitarray==3.8.0 \ --hash=sha256:004d518fa410e6da43386d20e07b576a41eb417ac67abf9f30fa75e125697199 \ --hash=sha256:014df8a9430276862392ac5d471697de042367996c49f32d0008585d2c60755a \ @@ -393,7 +415,9 @@ docutils==0.22.4 \ # via readme-renderer flatbuffers==25.12.19 \ --hash=sha256:7634f50c427838bb021c2d66a3d1168e9d199b0607e6329399f04846d42e20b4 - # via tensorflow + # via + # ai-edge-litert + # tensorflow gast==0.7.0 \ --hash=sha256:0bb14cd1b806722e91ddbab6fb86bba148c22b40e7ff11e248974e04c8adfdae \ --hash=sha256:99cbf1365633a74099f69c59bd650476b96baa5ef196fec88032b00b31ba36f7 @@ -823,6 +847,7 @@ numpy==2.4.0 \ --hash=sha256:f935c4493eda9069851058fa0d9e39dbf6286be690066509305e52912714dbb2 # via # -r python/python_requirements.in + # ai-edge-litert # h5py # keras # ml-dtypes @@ -1058,6 +1083,7 @@ protobuf==6.33.2 \ --hash=sha256:fc2a0e8b05b180e5fc0dd1559fe8ebdae21a27e81ac77728fb6c42b12c7419b4 # via # -r python/python_requirements.in + # ai-edge-litert # tensorboard # tensorflow pycparser==2.23 \ @@ -1217,6 +1243,10 @@ termcolor==3.3.0 \ --hash=sha256:348871ca648ec6a9a983a13ab626c0acce02f515b9e1983332b17af7979521c5 \ --hash=sha256:cf642efadaf0a8ebbbf4bc7a31cec2f9b5f21a9f726f4ccbb08192c9c26f43a5 # via tensorflow +tqdm==4.67.3 \ + --hash=sha256:7d825f03f89244ef73f1d4ce193cb1774a8179fd96f31d7e1dcde62092b960bb \ + --hash=sha256:ee1e4c0e59148062281c49d80b25b67771a127c85fc9676d3be5f243206826bf + # via ai-edge-litert twine==6.2.0 \ --hash=sha256:418ebf08ccda9a8caaebe414433b0ba5e25eb5e4a927667122fbe8f829f985d8 \ --hash=sha256:e5ed0d2fd70c9959770dce51c8f39c8945c574e18173a7b81802dab51b4b75cf @@ -1225,6 +1255,7 @@ typing-extensions==4.15.0 \ --hash=sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466 \ --hash=sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548 # via + # ai-edge-litert # grpcio # optree # tensorflow diff --git a/tensorflow/lite/micro/tools/BUILD b/tensorflow/lite/micro/tools/BUILD index eecf1b4c0ad..ff0116c9d38 100644 --- a/tensorflow/lite/micro/tools/BUILD +++ b/tensorflow/lite/micro/tools/BUILD @@ -213,7 +213,7 @@ py_binary( ":layer_by_layer_schema_py", ":model_transforms_utils", requirement("absl_py"), - requirement("tensorflow"), + requirement("ai-edge-litert"), "//python/tflite_micro:runtime", "//tensorflow/lite/tools:flatbuffer_utils", ], diff --git a/tensorflow/lite/micro/tools/layer_by_layer_debugger.py b/tensorflow/lite/micro/tools/layer_by_layer_debugger.py index 1f7925bf38d..fc601f253bd 100644 --- a/tensorflow/lite/micro/tools/layer_by_layer_debugger.py +++ b/tensorflow/lite/micro/tools/layer_by_layer_debugger.py @@ -21,10 +21,30 @@ from absl import flags from absl import logging import numpy as np -import tensorflow as tf +OpResolverType = None +try: + import ai_edge_litert.interpreter as tflite_interp + try: + from ai_edge_litert.interpreter import OpResolverType + except ImportError: + pass +except ImportError: + try: + import tflite_runtime.interpreter as tflite_interp + try: + from tflite_runtime.interpreter import OpResolverType + except ImportError: + pass + except ImportError: + try: + import tensorflow.lite as tflite_interp + OpResolverType = tflite_interp.experimental.OpResolverType + except ImportError: + raise ImportError( + "Could not import ai_edge_litert, tflite_runtime, or tensorflow." + ) from tflite_micro.tensorflow.lite.tools import flatbuffer_utils -from tensorflow.python.platform import gfile from tflite_micro.python.tflite_micro import runtime from tflite_micro.tensorflow.lite.micro.tools import layer_by_layer_schema_py_generated as layer_schema_fb from tflite_micro.tensorflow.lite.micro.tools import model_transforms_utils @@ -160,7 +180,7 @@ def GenerateRandomInputTfLiteComparison(tflm_interpreter, tflite_interpreter, def ReadDebugFile(): - with gfile.GFile(_DEBUG_FILE.value, "rb") as debug_file_handle: + with open(_DEBUG_FILE.value, "rb") as debug_file_handle: debug_bytearray = bytearray(debug_file_handle.read()) flatbuffer_root_object = layer_schema_fb.ModelTestData.GetRootAs( debug_bytearray, 0) @@ -194,10 +214,18 @@ def main(_) -> None: intrepreter_config=runtime.InterpreterConfig.kPreserveAllTensors, ) - tflite_interpreter = tf.lite.Interpreter( - model_path=_INPUT_TFLITE_FILE.value, - experimental_preserve_all_tensors=True, - ) + kwargs = { + "model_path": _INPUT_TFLITE_FILE.value, + "experimental_preserve_all_tensors": True, + } + if OpResolverType is not None: + kwargs["experimental_op_resolver_type"] = OpResolverType.BUILTIN_REF + else: + logging.warning( + "Could not find OpResolverType. Reference kernels might not be used." + ) + + tflite_interpreter = tflite_interp.Interpreter(**kwargs) tflite_interpreter.allocate_tensors() diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index f7a662782ed..9ca2819a06b 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -20,6 +20,5 @@ py_library( visibility = ["//:__subpackages__"], deps = [ requirement("flatbuffers"), - requirement("tensorflow"), ], ) diff --git a/tensorflow/lite/python/schema_util.py b/tensorflow/lite/python/schema_util.py index e898a47318d..3f351f0caa4 100644 --- a/tensorflow/lite/python/schema_util.py +++ b/tensorflow/lite/python/schema_util.py @@ -14,7 +14,6 @@ # ============================================================================== """Schema utilities to get builtin code from operator code.""" -from tensorflow.python.util import all_util def get_builtin_code_from_operator_code(opcode): @@ -38,8 +37,3 @@ def get_builtin_code_from_operator_code(opcode): return max(opcode.builtinCode, opcode.deprecatedBuiltinCode) -_allowed_symbols = [ - 'get_builtin_code_from_operator_code', -] - -all_util.remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index 0250959943d..a84dadde5a7 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -8,7 +8,6 @@ py_library( deps = [ "//:tflite_micro_shim", requirement("flatbuffers"), - requirement("tensorflow"), "//tensorflow/lite/python:schema_py", "//tensorflow/lite/python:schema_util", ], @@ -20,7 +19,6 @@ py_library( deps = [ "//:tflite_micro_shim", requirement("flatbuffers"), - requirement("tensorflow"), "//tensorflow/lite/python:schema_py", ], ) @@ -50,7 +48,6 @@ py_test( deps = [ ":flatbuffer_utils", ":test_utils", - requirement("tensorflow"), ], ) @@ -60,6 +57,5 @@ py_test( deps = [ ":test_utils", ":visualize", - requirement("tensorflow"), ], ) diff --git a/tensorflow/lite/tools/flatbuffer_utils.py b/tensorflow/lite/tools/flatbuffer_utils.py index 71e1afed6b5..a3495e18bf3 100644 --- a/tensorflow/lite/tools/flatbuffer_utils.py +++ b/tensorflow/lite/tools/flatbuffer_utils.py @@ -31,7 +31,7 @@ from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb from tflite_micro.tensorflow.lite.python import schema_util -from tensorflow.python.platform import gfile +import os _TFLITE_FILE_IDENTIFIER = b'TFL3' @@ -55,9 +55,9 @@ def read_model(input_tflite_file): Returns: A python object corresponding to the input tflite file. """ - if not gfile.Exists(input_tflite_file): + if not os.path.exists(input_tflite_file): raise RuntimeError('Input file not found at %r\n' % input_tflite_file) - with gfile.GFile(input_tflite_file, 'rb') as input_file_handle: + with open(input_tflite_file, 'rb') as input_file_handle: model_bytearray = bytearray(input_file_handle.read()) return read_model_from_bytearray(model_bytearray) @@ -144,7 +144,7 @@ def write_model(model_object, output_tflite_file): model_object = copy.deepcopy(model_object) byte_swap_tflite_model_obj(model_object, 'big', 'little') model_bytearray = convert_object_to_bytearray(model_object) - with gfile.GFile(output_tflite_file, 'wb') as output_file_handle: + with open(output_tflite_file, 'wb') as output_file_handle: output_file_handle.write(model_bytearray) diff --git a/tensorflow/lite/tools/flatbuffer_utils_test.py b/tensorflow/lite/tools/flatbuffer_utils_test.py index 13074aaca5e..6fab6e7e832 100644 --- a/tensorflow/lite/tools/flatbuffer_utils_test.py +++ b/tensorflow/lite/tools/flatbuffer_utils_test.py @@ -21,20 +21,20 @@ from tflite_micro.tensorflow.lite.python import schema_py_generated as schema # pylint:disable=g-direct-tensorflow-import from tflite_micro.tensorflow.lite.tools import flatbuffer_utils from tflite_micro.tensorflow.lite.tools import test_utils -from tensorflow.python.framework import test_util -from tensorflow.python.platform import test +import unittest +import tempfile _SKIPPED_BUFFER_INDEX = 1 -class WriteReadModelTest(test_util.TensorFlowTestCase): +class WriteReadModelTest(unittest.TestCase): def testWriteReadModel(self): # 1. SETUP # Define the initial model initial_model = test_utils.build_mock_model() # Define temporary files - tmp_dir = self.get_temp_dir() + tmp_dir = tempfile.mkdtemp() model_filename = os.path.join(tmp_dir, 'model.tflite') # 2. INVOKE @@ -72,7 +72,7 @@ def testWriteReadModel(self): self.assertEqual(initial_buffer.data[i], final_buffer.data[i]) -class StripStringsTest(test_util.TensorFlowTestCase): +class StripStringsTest(unittest.TestCase): def testStripStrings(self): # 1. SETUP @@ -120,7 +120,7 @@ def testStripStrings(self): self.assertEqual(initial_buffer.data[i], final_buffer.data[i]) -class RandomizeWeightsTest(test_util.TensorFlowTestCase): +class RandomizeWeightsTest(unittest.TestCase): def testRandomizeWeights(self): # 1. SETUP @@ -204,7 +204,7 @@ def testRandomizeSomeWeights(self): self.assertEqual(initial_buffer.data[j], final_buffer.data[j]) -class XxdOutputToBytesTest(test_util.TensorFlowTestCase): +class XxdOutputToBytesTest(unittest.TestCase): def testXxdOutputToBytes(self): # 1. SETUP @@ -213,7 +213,7 @@ def testXxdOutputToBytes(self): initial_bytes = flatbuffer_utils.convert_object_to_bytearray(initial_model) # Define temporary files - tmp_dir = self.get_temp_dir() + tmp_dir = tempfile.mkdtemp() model_filename = os.path.join(tmp_dir, 'model.tflite') # 2. Write model to temporary file (will be used as input for xxd) @@ -236,7 +236,7 @@ def testXxdOutputToBytes(self): self.assertEqual(initial_bytes, final_bytes) -class CountResourceVariablesTest(test_util.TensorFlowTestCase): +class CountResourceVariablesTest(unittest.TestCase): def testCountResourceVariables(self): # 1. SETUP @@ -250,7 +250,7 @@ def testCountResourceVariables(self): flatbuffer_utils.count_resource_variables(initial_model), 1) -class GetOptionsTest(test_util.TensorFlowTestCase): +class GetOptionsTest(unittest.TestCase): op: schema.Operator op_t: schema.OperatorT @@ -290,4 +290,4 @@ def test_get_options_op_type_does_not_match(self): if __name__ == '__main__': - test.main() + unittest.main() diff --git a/tensorflow/lite/tools/visualize_test.py b/tensorflow/lite/tools/visualize_test.py index 68de38cc9d7..4deb5a7ccbb 100644 --- a/tensorflow/lite/tools/visualize_test.py +++ b/tensorflow/lite/tools/visualize_test.py @@ -18,11 +18,11 @@ from tflite_micro.tensorflow.lite.tools import test_utils from tflite_micro.tensorflow.lite.tools import visualize -from tensorflow.python.framework import test_util -from tensorflow.python.platform import test +import unittest +import tempfile -class VisualizeTest(test_util.TensorFlowTestCase): +class VisualizeTest(unittest.TestCase): def testTensorTypeToName(self): self.assertEqual('FLOAT32', visualize.TensorTypeToName(0)) @@ -42,7 +42,7 @@ def testFlatbufferToDict(self): def testVisualize(self): model = test_utils.build_mock_flatbuffer_model() - tmp_dir = self.get_temp_dir() + tmp_dir = tempfile.mkdtemp() model_filename = os.path.join(tmp_dir, 'model.tflite') with open(model_filename, 'wb') as model_file: model_file.write(model) @@ -61,4 +61,4 @@ def testVisualize(self): if __name__ == '__main__': - test.main() + unittest.main() From a70a7345cf71260d9ee8254955112ff3da129077 Mon Sep 17 00:00:00 2001 From: Esun Kim Date: Wed, 6 May 2026 10:07:32 -0700 Subject: [PATCH 2/5] Vendoring our own utils --- python/tflite_micro/BUILD | 42 +- python/tflite_micro/flatbuffer_utils.py | 537 ++++++++++++++++++ python/tflite_micro/flatbuffer_utils_test.py | 293 ++++++++++ python/tflite_micro/runtime.py | 2 +- python/tflite_micro/test_utils.py | 299 ++++++++++ python/tflite_micro/visualize_test.py | 64 +++ tensorflow/lite/python/BUILD | 1 + tensorflow/lite/python/schema_util.py | 6 + tensorflow/lite/tools/BUILD | 4 + tensorflow/lite/tools/flatbuffer_utils.py | 8 +- .../lite/tools/flatbuffer_utils_test.py | 22 +- tensorflow/lite/tools/visualize_test.py | 10 +- 12 files changed, 1265 insertions(+), 23 deletions(-) create mode 100644 python/tflite_micro/flatbuffer_utils.py create mode 100644 python/tflite_micro/flatbuffer_utils_test.py create mode 100644 python/tflite_micro/test_utils.py create mode 100644 python/tflite_micro/visualize_test.py diff --git a/python/tflite_micro/BUILD b/python/tflite_micro/BUILD index af6b91c3e05..9768b4a540d 100644 --- a/python/tflite_micro/BUILD +++ b/python/tflite_micro/BUILD @@ -71,6 +71,45 @@ pybind_extension( ], ) +py_library( + name = "flatbuffer_utils", + srcs = ["flatbuffer_utils.py"], + visibility = ["//visibility:public"], + deps = [ + "//:tflite_micro_shim", + requirement("flatbuffers"), + "//tensorflow/lite/python:schema_py", + ], +) + +py_library( + name = "test_utils", + srcs = ["test_utils.py"], + deps = [ + "//:tflite_micro_shim", + requirement("flatbuffers"), + "//tensorflow/lite/python:schema_py", + ], +) + +py_test( + name = "flatbuffer_utils_test", + srcs = ["flatbuffer_utils_test.py"], + deps = [ + ":flatbuffer_utils", + ":test_utils", + ], +) + +py_test( + name = "visualize_test", + srcs = ["visualize_test.py"], + deps = [ + ":test_utils", + "//tensorflow/lite/tools:visualize", + ], +) + py_library( name = "runtime", srcs = [ @@ -85,7 +124,7 @@ py_library( "//:tflite_micro_shim", requirement("numpy"), "//tensorflow/lite/micro/tools:generate_test_for_model", - "//tensorflow/lite/tools:flatbuffer_utils", + ":flatbuffer_utils", ], ) @@ -165,7 +204,6 @@ py_package( "tensorflow.lite.micro.tools", "tensorflow.lite.micro.tools.generate_test_for_model", "tensorflow.lite.python", - "tensorflow.lite.tools.flatbuffer_utils", ], deps = [ ":postinstall_check", diff --git a/python/tflite_micro/flatbuffer_utils.py b/python/tflite_micro/flatbuffer_utils.py new file mode 100644 index 00000000000..d20a8e3b0ab --- /dev/null +++ b/python/tflite_micro/flatbuffer_utils.py @@ -0,0 +1,537 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions for FlatBuffers. + +All functions that are commonly used to work with FlatBuffers. + +Refer to the tensorflow lite flatbuffer schema here: +tensorflow/lite/schema/schema.fbs +""" + +import copy +import random +import re +import struct +import sys +from typing import Optional, Type, TypeVar, Union + +import flatbuffers + +from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb +import os + + +def get_builtin_code_from_operator_code(opcode): + """Return the builtin code of the given operator code. + + The following method is introduced to resolve op builtin code shortage + problem. The new builtin operator will be assigned to the extended builtin + code field in the flatbuffer schema. Those methods helps to hide builtin code + details. + + Args: + opcode: Operator code. + + Returns: + The builtin code of the given operator code. + """ + # Access BuiltinCode() method first if available. + if hasattr(opcode, 'BuiltinCode') and callable(opcode.BuiltinCode): + return max(opcode.BuiltinCode(), opcode.DeprecatedBuiltinCode()) + + return max(opcode.builtinCode, opcode.deprecatedBuiltinCode) + +_TFLITE_FILE_IDENTIFIER = b'TFL3' + + +def convert_bytearray_to_object(model_bytearray): + """Converts a tflite model from a bytearray to an object for parsing.""" + model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0) + return schema_fb.ModelT.InitFromObj(model_object) + + +def read_model(input_tflite_file): + """Reads a tflite model as a python object. + + Args: + input_tflite_file: Full path name to the input tflite file + + Raises: + RuntimeError: If input_tflite_file path is invalid. + IOError: If input_tflite_file cannot be opened. + + Returns: + A python object corresponding to the input tflite file. + """ + if not os.path.exists(input_tflite_file): + raise RuntimeError('Input file not found at %r\n' % input_tflite_file) + with open(input_tflite_file, 'rb') as input_file_handle: + model_bytearray = bytearray(input_file_handle.read()) + return read_model_from_bytearray(model_bytearray) + + +def read_model_from_bytearray(model_bytearray): + """Reads a tflite model as a python object. + + Args: + model_bytearray: TFLite model in bytearray format. + + Returns: + A python object corresponding to the input tflite file. + """ + model = convert_bytearray_to_object(model_bytearray) + if sys.byteorder == 'big': + byte_swap_tflite_model_obj(model, 'little', 'big') + + # Offset handling for models > 2GB + for buffer in model.buffers: + if buffer.offset: + buffer.data = model_bytearray[buffer.offset : buffer.offset + buffer.size] + buffer.offset = 0 + buffer.size = 0 + for subgraph in model.subgraphs: + for op in subgraph.operators: + if op.largeCustomOptionsOffset: + op.customOptions = model_bytearray[ + op.largeCustomOptionsOffset : op.largeCustomOptionsOffset + + op.largeCustomOptionsSize + ] + op.largeCustomOptionsOffset = 0 + op.largeCustomOptionsSize = 0 + + return model + + +def read_model_with_mutable_tensors(input_tflite_file): + """Reads a tflite model as a python object with mutable tensors. + + Similar to read_model() with the addition that the returned object has + mutable tensors (read_model() returns an object with immutable tensors). + + NOTE: This API only works for TFLite generated with + _experimental_use_buffer_offset=false + + Args: + input_tflite_file: Full path name to the input tflite file + + Raises: + RuntimeError: If input_tflite_file path is invalid. + IOError: If input_tflite_file cannot be opened. + + Returns: + A mutable python object corresponding to the input tflite file. + """ + return copy.deepcopy(read_model(input_tflite_file)) + + +def convert_object_to_bytearray(model_object, extra_buffer=b''): + """Converts a tflite model from an object to a immutable bytearray.""" + # Initial size of the buffer, which will grow automatically if needed + builder = flatbuffers.Builder(1024) + model_offset = model_object.Pack(builder) + builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER) + model_bytearray = bytes(builder.Output()) + model_bytearray = model_bytearray + extra_buffer + return model_bytearray + + +def write_model(model_object, output_tflite_file): + """Writes the tflite model, a python object, into the output file. + + NOTE: This API only works for TFLite generated with + _experimental_use_buffer_offset=false + + Args: + model_object: A tflite model as a python object + output_tflite_file: Full path name to the output tflite file. + + Raises: + IOError: If output_tflite_file path is invalid or cannot be opened. + """ + if sys.byteorder == 'big': + model_object = copy.deepcopy(model_object) + byte_swap_tflite_model_obj(model_object, 'big', 'little') + model_bytearray = convert_object_to_bytearray(model_object) + with open(output_tflite_file, 'wb') as output_file_handle: + output_file_handle.write(model_bytearray) + + +def strip_strings(model): + """Strips all nonessential strings from the model to reduce model size. + + We remove the following strings: + (find strings by searching ":string" in the tensorflow lite flatbuffer schema) + 1. Model description + 2. SubGraph name + 3. Tensor names + We retain OperatorCode custom_code and Metadata name. + + Args: + model: The model from which to remove nonessential strings. + """ + + model.description = None + for subgraph in model.subgraphs: + subgraph.name = None + for tensor in subgraph.tensors: + tensor.name = None + # We clear all signature_def structure, since without names it is useless. + model.signatureDefs = None + + +def type_to_name(tensor_type): + """Converts a numerical enum to a readable tensor type.""" + for name, value in schema_fb.TensorType.__dict__.items(): + if value == tensor_type: + return name + return None + + +def randomize_weights(model, random_seed=0, buffers_to_skip=None): + """Randomize weights in a model. + + Args: + model: The model in which to randomize weights. + random_seed: The input to the random number generator (default value is 0). + buffers_to_skip: The list of buffer indices to skip. The weights in these + buffers are left unmodified. + """ + + # The input to the random seed generator. The default value is 0. + random.seed(random_seed) + + # Parse model buffers which store the model weights + buffers = model.buffers + buffer_ids = range(1, len(buffers)) # ignore index 0 as it's always None + if buffers_to_skip is not None: + buffer_ids = [idx for idx in buffer_ids if idx not in buffers_to_skip] + + buffer_types = {} + for graph in model.subgraphs: + for op in graph.operators: + if op.inputs is None: + break + for input_idx in op.inputs: + tensor = graph.tensors[input_idx] + buffer_types[tensor.buffer] = type_to_name(tensor.type) + + for i in buffer_ids: + buffer_i_data = buffers[i].data + buffer_i_size = 0 if buffer_i_data is None else buffer_i_data.size + if buffer_i_size == 0: + continue + + # Raw data buffers are of type ubyte (or uint8) whose values lie in the + # range [0, 255]. Those ubytes (or unint8s) are the underlying + # representation of each datatype. For example, a bias tensor of type + # int32 appears as a buffer 4 times it's length of type ubyte (or uint8). + # For floats, we need to generate a valid float and then pack it into + # the raw bytes in place. + buffer_type = buffer_types.get(i, 'INT8') + if buffer_type.startswith('FLOAT'): + format_code = 'e' if buffer_type == 'FLOAT16' else 'f' + for offset in range(0, buffer_i_size, struct.calcsize(format_code)): + value = random.uniform(-0.5, 0.5) # See http://b/152324470#comment2 + struct.pack_into(format_code, buffer_i_data, offset, value) + else: + for j in range(buffer_i_size): + buffer_i_data[j] = random.randint(0, 255) + + +def rename_custom_ops(model, map_custom_op_renames): + """Rename custom ops so they use the same naming style as builtin ops. + + Args: + model: The input tflite model. + map_custom_op_renames: A mapping from old to new custom op names. + """ + for op_code in model.operatorCodes: + if op_code.customCode: + op_code_str = op_code.customCode.decode('ascii') + if op_code_str in map_custom_op_renames: + op_code.customCode = map_custom_op_renames[op_code_str].encode('ascii') + + +def opcode_to_name(model, op_code): + """Converts a TFLite op_code to the human readable name. + + Args: + model: The input tflite model. + op_code: The op_code to resolve to a readable name. + + Returns: + A string containing the human readable op name, or None if not resolvable. + """ + op = model.operatorCodes[op_code] + code = max(op.builtinCode, op.deprecatedBuiltinCode) + for name, value in vars(schema_fb.BuiltinOperator).items(): + if value == code: + return name + return None + + +def xxd_output_to_bytes(input_cc_file): + """Converts xxd output C++ source file to bytes (immutable). + + Args: + input_cc_file: Full path name to th C++ source file dumped by xxd + + Raises: + RuntimeError: If input_cc_file path is invalid. + IOError: If input_cc_file cannot be opened. + + Returns: + A bytearray corresponding to the input cc file array. + """ + # Match hex values in the string with comma as separator + pattern = re.compile(r'\W*(0x[0-9a-fA-F,x ]+).*') + + model_bytearray = bytearray() + + with open(input_cc_file) as file_handle: + for line in file_handle: + values_match = pattern.match(line) + + if values_match is None: + continue + + # Match in the parentheses (hex array only) + list_text = values_match.group(1) + + # Extract hex values (text) from the line + # e.g. 0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, + values_text = filter(None, list_text.split(',')) + + # Convert to hex + values = [int(x, base=16) for x in values_text] + model_bytearray.extend(values) + + return bytes(model_bytearray) + + +def xxd_output_to_object(input_cc_file): + """Converts xxd output C++ source file to object. + + Args: + input_cc_file: Full path name to th C++ source file dumped by xxd + + Raises: + RuntimeError: If input_cc_file path is invalid. + IOError: If input_cc_file cannot be opened. + + Returns: + A python object corresponding to the input tflite file. + """ + model_bytes = xxd_output_to_bytes(input_cc_file) + return convert_bytearray_to_object(model_bytes) + + +def byte_swap_buffer_content(buffer, chunksize, from_endiness, to_endiness): + """Helper function for byte-swapping the buffers field.""" + to_swap = [ + buffer.data[i : i + chunksize] + for i in range(0, len(buffer.data), chunksize) + ] + buffer.data = b''.join([ + int.from_bytes(byteswap, from_endiness).to_bytes(chunksize, to_endiness) + for byteswap in to_swap + ]) + + +def byte_swap_string_content(buffer, from_endiness, to_endiness): + """Helper function for byte-swapping the string buffer. + + Args: + buffer: TFLite string buffer of from_endiness format. + from_endiness: The original endianness format of the string buffer. + to_endiness: The destined endianness format of the string buffer. + """ + num_of_strings = int.from_bytes(buffer.data[0:4], from_endiness) + string_content = bytearray(buffer.data[4 * (num_of_strings + 2) :]) + prefix_data = b''.join([ + int.from_bytes(buffer.data[i : i + 4], from_endiness).to_bytes( + 4, to_endiness + ) + for i in range(0, (num_of_strings + 1) * 4 + 1, 4) + ]) + buffer.data = prefix_data + string_content + + +def byte_swap_tflite_model_obj(model, from_endiness, to_endiness): + """Byte swaps the buffers field in a TFLite model. + + Args: + model: TFLite model object of from_endiness format. + from_endiness: The original endianness format of the buffers in model. + to_endiness: The destined endianness format of the buffers in model. + """ + if model is None: + return + # Get all the constant buffers, byte swapping them as per their data types + buffer_swapped = [] + types_of_16_bits = [ + schema_fb.TensorType.FLOAT16, + schema_fb.TensorType.INT16, + schema_fb.TensorType.UINT16, + ] + types_of_32_bits = [ + schema_fb.TensorType.FLOAT32, + schema_fb.TensorType.INT32, + schema_fb.TensorType.COMPLEX64, + schema_fb.TensorType.UINT32, + ] + types_of_64_bits = [ + schema_fb.TensorType.INT64, + schema_fb.TensorType.FLOAT64, + schema_fb.TensorType.COMPLEX128, + schema_fb.TensorType.UINT64, + ] + for subgraph in model.subgraphs: + for tensor in subgraph.tensors: + if ( + tensor.buffer > 0 + and tensor.buffer < len(model.buffers) + and tensor.buffer not in buffer_swapped + and model.buffers[tensor.buffer].data is not None + ): + if tensor.type == schema_fb.TensorType.STRING: + byte_swap_string_content( + model.buffers[tensor.buffer], from_endiness, to_endiness + ) + elif tensor.type in types_of_16_bits: + byte_swap_buffer_content( + model.buffers[tensor.buffer], 2, from_endiness, to_endiness + ) + elif tensor.type in types_of_32_bits: + byte_swap_buffer_content( + model.buffers[tensor.buffer], 4, from_endiness, to_endiness + ) + elif tensor.type in types_of_64_bits: + byte_swap_buffer_content( + model.buffers[tensor.buffer], 8, from_endiness, to_endiness + ) + else: + continue + buffer_swapped.append(tensor.buffer) + + +def byte_swap_tflite_buffer(tflite_model, from_endiness, to_endiness): + """Generates a new model byte array after byte swapping its buffers field. + + Args: + tflite_model: TFLite flatbuffer in a byte array. + from_endiness: The original endianness format of the buffers in + tflite_model. + to_endiness: The destined endianness format of the buffers in tflite_model. + + Returns: + TFLite flatbuffer in a byte array, after being byte swapped to to_endiness + format. + """ + if tflite_model is None: + return None + # Load TFLite Flatbuffer byte array into an object. + model = convert_bytearray_to_object(tflite_model) + + # Byte swapping the constant buffers as per their data types + byte_swap_tflite_model_obj(model, from_endiness, to_endiness) + + # Return a TFLite flatbuffer as a byte array. + return convert_object_to_bytearray(model) + + +def count_resource_variables(model): + """Calculates the number of unique resource variables in a model. + + Args: + model: the input tflite model, either as bytearray or object. + + Returns: + An integer number representing the number of unique resource variables. + """ + if not isinstance(model, schema_fb.ModelT): + model = convert_bytearray_to_object(model) + unique_shared_names = set() + for subgraph in model.subgraphs: + if subgraph.operators is None: + continue + for op in subgraph.operators: + builtin_code = get_builtin_code_from_operator_code( + model.operatorCodes[op.opcodeIndex] + ) + if builtin_code == schema_fb.BuiltinOperator.VAR_HANDLE: + unique_shared_names.add(op.builtinOptions.sharedName) + return len(unique_shared_names) + + +OptsT = TypeVar('OptsT') + + +def get_options_as( + op: Union[schema_fb.Operator, schema_fb.OperatorT], opts_type: Type[OptsT] +) -> Optional[OptsT]: + """Get the options of an operator as the specified type. + + Requested type must be an object-api type (ends in 'T'). + + Args: + op: The operator to get the options from. + opts_type: The type of the options to get. + + Returns: + The options as the specified type, or None if the options are not of the + specified type. + + Raises: + ValueError: If the specified type is not a valid options type. + """ + + err = ValueError(f'Unsupported options type: {opts_type}') + type_name: str = opts_type.__name__ + if not type_name.endswith('T'): + raise err + base_type_name = type_name.removesuffix('T') + is_opt_1_type = hasattr(schema_fb.BuiltinOptions, base_type_name) + if not is_opt_1_type and not hasattr( + schema_fb.BuiltinOptions2, base_type_name + ): + raise err + + if isinstance(op, schema_fb.Operator): + if not is_opt_1_type: + enum_val = getattr(schema_fb.BuiltinOptions2, base_type_name) + opts_creator = schema_fb.BuiltinOptions2Creator + raw_ops = op.BuiltinOptions2() + actual_enum_val = op.BuiltinOptions2Type() + else: + enum_val = getattr(schema_fb.BuiltinOptions, base_type_name) + opts_creator = schema_fb.BuiltinOptionsCreator + raw_ops = op.BuiltinOptions() + actual_enum_val = op.BuiltinOptionsType() + if raw_ops is None or actual_enum_val != enum_val: + return None + return opts_creator(enum_val, raw_ops) + + elif isinstance(op, schema_fb.OperatorT): + if is_opt_1_type: + raw_ops_t = op.builtinOptions + else: + raw_ops_t = op.builtinOptions2 + if raw_ops_t is None or not isinstance(raw_ops_t, opts_type): + return None + return raw_ops_t + + else: + return None diff --git a/python/tflite_micro/flatbuffer_utils_test.py b/python/tflite_micro/flatbuffer_utils_test.py new file mode 100644 index 00000000000..75f5b61e598 --- /dev/null +++ b/python/tflite_micro/flatbuffer_utils_test.py @@ -0,0 +1,293 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for flatbuffer_utils.py.""" +import copy +import os +import subprocess +import sys + +from tflite_micro.tensorflow.lite.python import schema_py_generated as schema # pylint:disable=g-direct-tensorflow-import +from tflite_micro.python.tflite_micro import flatbuffer_utils +from tflite_micro.python.tflite_micro import test_utils +import unittest +import tempfile + +_SKIPPED_BUFFER_INDEX = 1 + + +class WriteReadModelTest(unittest.TestCase): + + def testWriteReadModel(self): + # 1. SETUP + # Define the initial model + initial_model = test_utils.build_mock_model() + # Define temporary files + tmp_dir = tempfile.mkdtemp() + model_filename = os.path.join(tmp_dir, 'model.tflite') + + # 2. INVOKE + # Invoke the write_model and read_model functions + flatbuffer_utils.write_model(initial_model, model_filename) + final_model = flatbuffer_utils.read_model(model_filename) + + # 3. VALIDATE + # Validate that the initial and final models are the same + # Validate the description + self.assertEqual(initial_model.description, final_model.description) + # Validate the main subgraph's name, inputs, outputs, operators and tensors + initial_subgraph = initial_model.subgraphs[0] + final_subgraph = final_model.subgraphs[0] + self.assertEqual(initial_subgraph.name, final_subgraph.name) + for i in range(len(initial_subgraph.inputs)): + self.assertEqual(initial_subgraph.inputs[i], final_subgraph.inputs[i]) + for i in range(len(initial_subgraph.outputs)): + self.assertEqual(initial_subgraph.outputs[i], final_subgraph.outputs[i]) + for i in range(len(initial_subgraph.operators)): + self.assertEqual(initial_subgraph.operators[i].opcodeIndex, + final_subgraph.operators[i].opcodeIndex) + initial_tensors = initial_subgraph.tensors + final_tensors = final_subgraph.tensors + for i in range(len(initial_tensors)): + self.assertEqual(initial_tensors[i].name, final_tensors[i].name) + self.assertEqual(initial_tensors[i].type, final_tensors[i].type) + self.assertEqual(initial_tensors[i].buffer, final_tensors[i].buffer) + for j in range(len(initial_tensors[i].shape)): + self.assertEqual(initial_tensors[i].shape[j], final_tensors[i].shape[j]) + # Validate the first valid buffer (index 0 is always None) + initial_buffer = initial_model.buffers[1].data + final_buffer = final_model.buffers[1].data + for i in range(initial_buffer.size): + self.assertEqual(initial_buffer.data[i], final_buffer.data[i]) + + +class StripStringsTest(unittest.TestCase): + + def testStripStrings(self): + # 1. SETUP + # Define the initial model + initial_model = test_utils.build_mock_model() + final_model = copy.deepcopy(initial_model) + + # 2. INVOKE + # Invoke the strip_strings function + flatbuffer_utils.strip_strings(final_model) + + # 3. VALIDATE + # Validate that the initial and final models are the same except strings + # Validate the description + self.assertIsNotNone(initial_model.description) + self.assertIsNone(final_model.description) + self.assertIsNotNone(initial_model.signatureDefs) + self.assertIsNone(final_model.signatureDefs) + + # Validate the main subgraph's name, inputs, outputs, operators and tensors + initial_subgraph = initial_model.subgraphs[0] + final_subgraph = final_model.subgraphs[0] + self.assertIsNotNone(initial_model.subgraphs[0].name) + self.assertIsNone(final_model.subgraphs[0].name) + for i in range(len(initial_subgraph.inputs)): + self.assertEqual(initial_subgraph.inputs[i], final_subgraph.inputs[i]) + for i in range(len(initial_subgraph.outputs)): + self.assertEqual(initial_subgraph.outputs[i], final_subgraph.outputs[i]) + for i in range(len(initial_subgraph.operators)): + self.assertEqual(initial_subgraph.operators[i].opcodeIndex, + final_subgraph.operators[i].opcodeIndex) + initial_tensors = initial_subgraph.tensors + final_tensors = final_subgraph.tensors + for i in range(len(initial_tensors)): + self.assertIsNotNone(initial_tensors[i].name) + self.assertIsNone(final_tensors[i].name) + self.assertEqual(initial_tensors[i].type, final_tensors[i].type) + self.assertEqual(initial_tensors[i].buffer, final_tensors[i].buffer) + for j in range(len(initial_tensors[i].shape)): + self.assertEqual(initial_tensors[i].shape[j], final_tensors[i].shape[j]) + # Validate the first valid buffer (index 0 is always None) + initial_buffer = initial_model.buffers[1].data + final_buffer = final_model.buffers[1].data + for i in range(initial_buffer.size): + self.assertEqual(initial_buffer.data[i], final_buffer.data[i]) + + +class RandomizeWeightsTest(unittest.TestCase): + + def testRandomizeWeights(self): + # 1. SETUP + # Define the initial model + initial_model = test_utils.build_mock_model() + final_model = copy.deepcopy(initial_model) + + # 2. INVOKE + # Invoke the randomize_weights function + flatbuffer_utils.randomize_weights(final_model) + + # 3. VALIDATE + # Validate that the initial and final models are the same, except that + # the weights in the model buffer have been modified (i.e, randomized) + # Validate the description + self.assertEqual(initial_model.description, final_model.description) + # Validate the main subgraph's name, inputs, outputs, operators and tensors + initial_subgraph = initial_model.subgraphs[0] + final_subgraph = final_model.subgraphs[0] + self.assertEqual(initial_subgraph.name, final_subgraph.name) + for i in range(len(initial_subgraph.inputs)): + self.assertEqual(initial_subgraph.inputs[i], final_subgraph.inputs[i]) + for i in range(len(initial_subgraph.outputs)): + self.assertEqual(initial_subgraph.outputs[i], final_subgraph.outputs[i]) + for i in range(len(initial_subgraph.operators)): + self.assertEqual(initial_subgraph.operators[i].opcodeIndex, + final_subgraph.operators[i].opcodeIndex) + initial_tensors = initial_subgraph.tensors + final_tensors = final_subgraph.tensors + for i in range(len(initial_tensors)): + self.assertEqual(initial_tensors[i].name, final_tensors[i].name) + self.assertEqual(initial_tensors[i].type, final_tensors[i].type) + self.assertEqual(initial_tensors[i].buffer, final_tensors[i].buffer) + for j in range(len(initial_tensors[i].shape)): + self.assertEqual(initial_tensors[i].shape[j], final_tensors[i].shape[j]) + # Validate the first valid buffer (index 0 is always None) + initial_buffer = initial_model.buffers[1].data + final_buffer = final_model.buffers[1].data + for j in range(initial_buffer.size): + self.assertNotEqual(initial_buffer.data[j], final_buffer.data[j]) + + def testRandomizeSomeWeights(self): + # 1. SETUP + # Define the initial model + initial_model = test_utils.build_mock_model() + final_model = copy.deepcopy(initial_model) + + # 2. INVOKE + # Invoke the randomize_weights function, but skip the first buffer + flatbuffer_utils.randomize_weights( + final_model, buffers_to_skip=[_SKIPPED_BUFFER_INDEX]) + + # 3. VALIDATE + # Validate that the initial and final models are the same, except that + # the weights in the model buffer have been modified (i.e, randomized) + # Validate the description + self.assertEqual(initial_model.description, final_model.description) + # Validate the main subgraph's name, inputs, outputs, operators and tensors + initial_subgraph = initial_model.subgraphs[0] + final_subgraph = final_model.subgraphs[0] + self.assertEqual(initial_subgraph.name, final_subgraph.name) + for i, _ in enumerate(initial_subgraph.inputs): + self.assertEqual(initial_subgraph.inputs[i], final_subgraph.inputs[i]) + for i, _ in enumerate(initial_subgraph.outputs): + self.assertEqual(initial_subgraph.outputs[i], final_subgraph.outputs[i]) + for i, _ in enumerate(initial_subgraph.operators): + self.assertEqual(initial_subgraph.operators[i].opcodeIndex, + final_subgraph.operators[i].opcodeIndex) + initial_tensors = initial_subgraph.tensors + final_tensors = final_subgraph.tensors + for i, _ in enumerate(initial_tensors): + self.assertEqual(initial_tensors[i].name, final_tensors[i].name) + self.assertEqual(initial_tensors[i].type, final_tensors[i].type) + self.assertEqual(initial_tensors[i].buffer, final_tensors[i].buffer) + for j in range(len(initial_tensors[i].shape)): + self.assertEqual(initial_tensors[i].shape[j], final_tensors[i].shape[j]) + # Validate that the skipped buffer is unchanged. + initial_buffer = initial_model.buffers[_SKIPPED_BUFFER_INDEX].data + final_buffer = final_model.buffers[_SKIPPED_BUFFER_INDEX].data + for j in range(initial_buffer.size): + self.assertEqual(initial_buffer.data[j], final_buffer.data[j]) + + +class XxdOutputToBytesTest(unittest.TestCase): + + def testXxdOutputToBytes(self): + # 1. SETUP + # Define the initial model + initial_model = test_utils.build_mock_model() + initial_bytes = flatbuffer_utils.convert_object_to_bytearray(initial_model) + + # Define temporary files + tmp_dir = tempfile.mkdtemp() + model_filename = os.path.join(tmp_dir, 'model.tflite') + + # 2. Write model to temporary file (will be used as input for xxd) + flatbuffer_utils.write_model(initial_model, model_filename) + + # 3. DUMP WITH xxd + input_cc_file = os.path.join(tmp_dir, 'model.cc') + + command = 'xxd -i {} > {}'.format(model_filename, input_cc_file) + subprocess.call(command, shell=True) + + # 4. VALIDATE + final_bytes = flatbuffer_utils.xxd_output_to_bytes(input_cc_file) + if sys.byteorder == 'big': + final_bytes = flatbuffer_utils.byte_swap_tflite_buffer( + final_bytes, 'little', 'big' + ) + + # Validate that the initial and final bytearray are the same + self.assertEqual(initial_bytes, final_bytes) + + +class CountResourceVariablesTest(unittest.TestCase): + + def testCountResourceVariables(self): + # 1. SETUP + # Define the initial model + initial_model = test_utils.build_mock_model() + + # 2. Confirm that resource variables for mock model is 1 + # The mock model is created with two VAR HANDLE ops, but with the same + # shared name. + self.assertEqual( + flatbuffer_utils.count_resource_variables(initial_model), 1) + + +class GetOptionsTest(unittest.TestCase): + + op: schema.Operator + op_t: schema.OperatorT + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.op = test_utils.build_operator_with_options() + cls.op_t = schema.OperatorT.InitFromObj(cls.op) + + def test_get_options(self): + ty = schema.StableHLOCompositeOptionsT + opts = flatbuffer_utils.get_options_as(self.op, ty) + self.assertIsNotNone(opts) + self.assertIsInstance(opts, ty) + self.assertEqual(opts.decompositionSubgraphIndex, 10) + + def test_get_options_obj(self): + ty = schema.StableHLOCompositeOptionsT + opts = flatbuffer_utils.get_options_as(self.op_t, ty) + self.assertIsNotNone(opts) + self.assertIsInstance(opts, ty) + self.assertEqual(opts.decompositionSubgraphIndex, 10) + + def test_get_options_not_schema_type_raises(self): + with self.assertRaises(ValueError): + flatbuffer_utils.get_options_as(self.op, int) + + def test_get_options_not_object_type_raises(self): + with self.assertRaises(ValueError): + flatbuffer_utils.get_options_as(self.op, schema.StableHLOCompositeOptions) + + def test_get_options_op_type_does_not_match(self): + ty = schema.Conv2DOptionsT + opts = flatbuffer_utils.get_options_as(self.op, ty) + self.assertIsNone(opts) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/tflite_micro/runtime.py b/python/tflite_micro/runtime.py index 6d91b190a32..8ef04bdc109 100644 --- a/python/tflite_micro/runtime.py +++ b/python/tflite_micro/runtime.py @@ -16,7 +16,7 @@ import enum import os -from tflite_micro.tensorflow.lite.tools import flatbuffer_utils +from tflite_micro.python.tflite_micro import flatbuffer_utils from tflite_micro.python.tflite_micro import _runtime diff --git a/python/tflite_micro/test_utils.py b/python/tflite_micro/test_utils.py new file mode 100644 index 00000000000..44157143d5d --- /dev/null +++ b/python/tflite_micro/test_utils.py @@ -0,0 +1,299 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions that support testing. + +All functions that can be commonly used by various tests. +""" + +import flatbuffers +from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb + +TFLITE_SCHEMA_VERSION = 3 + + +def build_mock_flatbuffer_model(): + """Creates a flatbuffer containing an example model.""" + builder = flatbuffers.Builder(1024) + + schema_fb.BufferStart(builder) + buffer0_offset = schema_fb.BufferEnd(builder) + + schema_fb.BufferStartDataVector(builder, 12) + builder.PrependUint8(11) + builder.PrependUint8(10) + builder.PrependUint8(9) + builder.PrependUint8(8) + builder.PrependUint8(7) + builder.PrependUint8(6) + builder.PrependUint8(5) + builder.PrependUint8(4) + builder.PrependUint8(3) + builder.PrependUint8(2) + builder.PrependUint8(1) + builder.PrependUint8(0) + buffer1_data_offset = builder.EndVector() + schema_fb.BufferStart(builder) + schema_fb.BufferAddData(builder, buffer1_data_offset) + buffer1_offset = schema_fb.BufferEnd(builder) + + schema_fb.BufferStart(builder) + buffer2_offset = schema_fb.BufferEnd(builder) + + schema_fb.ModelStartBuffersVector(builder, 3) + builder.PrependUOffsetTRelative(buffer2_offset) + builder.PrependUOffsetTRelative(buffer1_offset) + builder.PrependUOffsetTRelative(buffer0_offset) + buffers_offset = builder.EndVector() + + string0_offset = builder.CreateString('input_tensor') + schema_fb.TensorStartShapeVector(builder, 3) + builder.PrependInt32(1) + builder.PrependInt32(2) + builder.PrependInt32(5) + shape0_offset = builder.EndVector() + schema_fb.TensorStart(builder) + schema_fb.TensorAddName(builder, string0_offset) + schema_fb.TensorAddShape(builder, shape0_offset) + schema_fb.TensorAddType(builder, 0) + schema_fb.TensorAddBuffer(builder, 0) + tensor0_offset = schema_fb.TensorEnd(builder) + + schema_fb.QuantizationParametersStartMinVector(builder, 5) + builder.PrependFloat32(0.5) + builder.PrependFloat32(2.0) + builder.PrependFloat32(5.0) + builder.PrependFloat32(10.0) + builder.PrependFloat32(20.0) + quant1_min_offset = builder.EndVector() + + schema_fb.QuantizationParametersStartMaxVector(builder, 5) + builder.PrependFloat32(10.0) + builder.PrependFloat32(20.0) + builder.PrependFloat32(-50.0) + builder.PrependFloat32(1.0) + builder.PrependFloat32(2.0) + quant1_max_offset = builder.EndVector() + + schema_fb.QuantizationParametersStartScaleVector(builder, 5) + builder.PrependFloat32(3.0) + builder.PrependFloat32(4.0) + builder.PrependFloat32(5.0) + builder.PrependFloat32(6.0) + builder.PrependFloat32(7.0) + quant1_scale_offset = builder.EndVector() + + schema_fb.QuantizationParametersStartZeroPointVector(builder, 5) + builder.PrependInt64(1) + builder.PrependInt64(2) + builder.PrependInt64(3) + builder.PrependInt64(-1) + builder.PrependInt64(-2) + quant1_zero_point_offset = builder.EndVector() + + schema_fb.QuantizationParametersStart(builder) + schema_fb.QuantizationParametersAddMin(builder, quant1_min_offset) + schema_fb.QuantizationParametersAddMax(builder, quant1_max_offset) + schema_fb.QuantizationParametersAddScale(builder, quant1_scale_offset) + schema_fb.QuantizationParametersAddZeroPoint(builder, + quant1_zero_point_offset) + quantization1_offset = schema_fb.QuantizationParametersEnd(builder) + + string1_offset = builder.CreateString('constant_tensor') + schema_fb.TensorStartShapeVector(builder, 3) + builder.PrependInt32(1) + builder.PrependInt32(2) + builder.PrependInt32(5) + shape1_offset = builder.EndVector() + schema_fb.TensorStart(builder) + schema_fb.TensorAddName(builder, string1_offset) + schema_fb.TensorAddShape(builder, shape1_offset) + schema_fb.TensorAddType(builder, schema_fb.TensorType.UINT8) + schema_fb.TensorAddBuffer(builder, 1) + schema_fb.TensorAddQuantization(builder, quantization1_offset) + tensor1_offset = schema_fb.TensorEnd(builder) + + string2_offset = builder.CreateString('output_tensor') + schema_fb.TensorStartShapeVector(builder, 3) + builder.PrependInt32(1) + builder.PrependInt32(2) + builder.PrependInt32(5) + shape2_offset = builder.EndVector() + schema_fb.TensorStart(builder) + schema_fb.TensorAddName(builder, string2_offset) + schema_fb.TensorAddShape(builder, shape2_offset) + schema_fb.TensorAddType(builder, 0) + schema_fb.TensorAddBuffer(builder, 2) + tensor2_offset = schema_fb.TensorEnd(builder) + + schema_fb.SubGraphStartTensorsVector(builder, 3) + builder.PrependUOffsetTRelative(tensor2_offset) + builder.PrependUOffsetTRelative(tensor1_offset) + builder.PrependUOffsetTRelative(tensor0_offset) + tensors_offset = builder.EndVector() + + schema_fb.SubGraphStartInputsVector(builder, 1) + builder.PrependInt32(0) + inputs_offset = builder.EndVector() + + schema_fb.SubGraphStartOutputsVector(builder, 1) + builder.PrependInt32(2) + outputs_offset = builder.EndVector() + + schema_fb.OperatorCodeStart(builder) + schema_fb.OperatorCodeAddBuiltinCode(builder, schema_fb.BuiltinOperator.ADD) + schema_fb.OperatorCodeAddDeprecatedBuiltinCode(builder, + schema_fb.BuiltinOperator.ADD) + schema_fb.OperatorCodeAddVersion(builder, 1) + code0_offset = schema_fb.OperatorCodeEnd(builder) + + schema_fb.OperatorCodeStart(builder) + schema_fb.OperatorCodeAddBuiltinCode(builder, + schema_fb.BuiltinOperator.VAR_HANDLE) + schema_fb.OperatorCodeAddDeprecatedBuiltinCode( + builder, schema_fb.BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES) + schema_fb.OperatorCodeAddVersion(builder, 1) + code1_offset = schema_fb.OperatorCodeEnd(builder) + + schema_fb.ModelStartOperatorCodesVector(builder, 2) + builder.PrependUOffsetTRelative(code1_offset) + builder.PrependUOffsetTRelative(code0_offset) + codes_offset = builder.EndVector() + + schema_fb.OperatorStartInputsVector(builder, 2) + builder.PrependInt32(0) + builder.PrependInt32(1) + op_inputs_offset = builder.EndVector() + + schema_fb.OperatorStartOutputsVector(builder, 1) + builder.PrependInt32(2) + op_outputs_offset = builder.EndVector() + + schema_fb.OperatorStart(builder) + schema_fb.OperatorAddOpcodeIndex(builder, 0) + schema_fb.OperatorAddInputs(builder, op_inputs_offset) + schema_fb.OperatorAddOutputs(builder, op_outputs_offset) + op0_offset = schema_fb.OperatorEnd(builder) + + shared_name = builder.CreateString('var') + schema_fb.VarHandleOptionsStart(builder) + schema_fb.VarHandleOptionsAddSharedName(builder, shared_name) + var_handle_options_offset = schema_fb.VarHandleOptionsEnd(builder) + + schema_fb.OperatorStart(builder) + schema_fb.OperatorAddOpcodeIndex(builder, 1) + schema_fb.OperatorAddBuiltinOptionsType( + builder, schema_fb.BuiltinOptions.VarHandleOptions) + schema_fb.OperatorAddBuiltinOptions(builder, var_handle_options_offset) + op1_offset = schema_fb.OperatorEnd(builder) + + schema_fb.OperatorStart(builder) + schema_fb.OperatorAddBuiltinOptionsType( + builder, schema_fb.BuiltinOptions.VarHandleOptions) + schema_fb.OperatorAddBuiltinOptions(builder, var_handle_options_offset) + op2_offset = schema_fb.OperatorEnd(builder) + + schema_fb.SubGraphStartOperatorsVector(builder, 3) + builder.PrependUOffsetTRelative(op2_offset) + builder.PrependUOffsetTRelative(op1_offset) + builder.PrependUOffsetTRelative(op0_offset) + ops_offset = builder.EndVector() + + string3_offset = builder.CreateString('subgraph_name') + schema_fb.SubGraphStart(builder) + schema_fb.SubGraphAddName(builder, string3_offset) + schema_fb.SubGraphAddTensors(builder, tensors_offset) + schema_fb.SubGraphAddInputs(builder, inputs_offset) + schema_fb.SubGraphAddOutputs(builder, outputs_offset) + schema_fb.SubGraphAddOperators(builder, ops_offset) + subgraph_offset = schema_fb.SubGraphEnd(builder) + + schema_fb.ModelStartSubgraphsVector(builder, 1) + builder.PrependUOffsetTRelative(subgraph_offset) + subgraphs_offset = builder.EndVector() + + signature_key = builder.CreateString('my_key') + input_tensor_string = builder.CreateString('input_tensor') + output_tensor_string = builder.CreateString('output_tensor') + + # Signature Inputs + schema_fb.TensorMapStart(builder) + schema_fb.TensorMapAddName(builder, input_tensor_string) + schema_fb.TensorMapAddTensorIndex(builder, 1) + input_tensor = schema_fb.TensorMapEnd(builder) + + # Signature Outputs + schema_fb.TensorMapStart(builder) + schema_fb.TensorMapAddName(builder, output_tensor_string) + schema_fb.TensorMapAddTensorIndex(builder, 2) + output_tensor = schema_fb.TensorMapEnd(builder) + + schema_fb.SignatureDefStartInputsVector(builder, 1) + builder.PrependUOffsetTRelative(input_tensor) + signature_inputs_offset = builder.EndVector() + schema_fb.SignatureDefStartOutputsVector(builder, 1) + builder.PrependUOffsetTRelative(output_tensor) + signature_outputs_offset = builder.EndVector() + + schema_fb.SignatureDefStart(builder) + schema_fb.SignatureDefAddSignatureKey(builder, signature_key) + schema_fb.SignatureDefAddInputs(builder, signature_inputs_offset) + schema_fb.SignatureDefAddOutputs(builder, signature_outputs_offset) + signature_offset = schema_fb.SignatureDefEnd(builder) + schema_fb.ModelStartSignatureDefsVector(builder, 1) + builder.PrependUOffsetTRelative(signature_offset) + signature_defs_offset = builder.EndVector() + + string4_offset = builder.CreateString('model_description') + schema_fb.ModelStart(builder) + schema_fb.ModelAddVersion(builder, TFLITE_SCHEMA_VERSION) + schema_fb.ModelAddOperatorCodes(builder, codes_offset) + schema_fb.ModelAddSubgraphs(builder, subgraphs_offset) + schema_fb.ModelAddDescription(builder, string4_offset) + schema_fb.ModelAddBuffers(builder, buffers_offset) + schema_fb.ModelAddSignatureDefs(builder, signature_defs_offset) + model_offset = schema_fb.ModelEnd(builder) + builder.Finish(model_offset) + model = builder.Output() + + return model + + +def build_operator_with_options() -> schema_fb.Operator: + """Builds an operator with the given options.""" + builder = flatbuffers.Builder(1024) + schema_fb.StableHLOCompositeOptionsStart(builder) + schema_fb.StableHLOCompositeOptionsAddDecompositionSubgraphIndex(builder, 10) + opts = schema_fb.StableHLOCompositeOptionsEnd(builder) + schema_fb.OperatorStart(builder) + schema_fb.OperatorAddBuiltinOptions2(builder, opts) + schema_fb.OperatorAddBuiltinOptions2Type( + builder, schema_fb.BuiltinOptions2.StableHLOCompositeOptions + ) + op_offset = schema_fb.OperatorEnd(builder) + builder.Finish(op_offset) + return schema_fb.Operator.GetRootAs(builder.Output()) + + +def load_model_from_flatbuffer(flatbuffer_model): + """Loads a model as a python object from a flatbuffer model.""" + model = schema_fb.Model.GetRootAsModel(flatbuffer_model, 0) + model = schema_fb.ModelT.InitFromObj(model) + return model + + +def build_mock_model(): + """Creates an object containing an example model.""" + model = build_mock_flatbuffer_model() + return load_model_from_flatbuffer(model) diff --git a/python/tflite_micro/visualize_test.py b/python/tflite_micro/visualize_test.py new file mode 100644 index 00000000000..5c2fa31cff5 --- /dev/null +++ b/python/tflite_micro/visualize_test.py @@ -0,0 +1,64 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow Lite Python Interface: Sanity check.""" +import os +import re + +from tflite_micro.python.tflite_micro import test_utils +from tflite_micro.tensorflow.lite.tools import visualize +import unittest +import tempfile + + +class VisualizeTest(unittest.TestCase): + + def testTensorTypeToName(self): + self.assertEqual('FLOAT32', visualize.TensorTypeToName(0)) + + def testBuiltinCodeToName(self): + self.assertEqual('HASHTABLE_LOOKUP', visualize.BuiltinCodeToName(10)) + + def testFlatbufferToDict(self): + model = test_utils.build_mock_flatbuffer_model() + model_dict = visualize.CreateDictFromFlatbuffer(model) + self.assertEqual(test_utils.TFLITE_SCHEMA_VERSION, model_dict['version']) + self.assertEqual(1, len(model_dict['subgraphs'])) + self.assertEqual(2, len(model_dict['operator_codes'])) + self.assertEqual(3, len(model_dict['buffers'])) + self.assertEqual(3, len(model_dict['subgraphs'][0]['tensors'])) + self.assertEqual(0, model_dict['subgraphs'][0]['tensors'][0]['buffer']) + + def testVisualize(self): + model = test_utils.build_mock_flatbuffer_model() + tmp_dir = tempfile.mkdtemp() + model_filename = os.path.join(tmp_dir, 'model.tflite') + with open(model_filename, 'wb') as model_file: + model_file.write(model) + + html_text = visualize.create_html(model_filename) + + # It's hard to test debug output without doing a full HTML parse, + # but at least sanity check that expected identifiers are present. + self.assertRegex( + html_text, re.compile(r'%s' % model_filename, re.MULTILINE | re.DOTALL)) + self.assertRegex(html_text, + re.compile(r'input_tensor', re.MULTILINE | re.DOTALL)) + self.assertRegex(html_text, + re.compile(r'constant_tensor', re.MULTILINE | re.DOTALL)) + self.assertRegex(html_text, re.compile(r'ADD', re.MULTILINE | re.DOTALL)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index 9ca2819a06b..f7a662782ed 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -20,5 +20,6 @@ py_library( visibility = ["//:__subpackages__"], deps = [ requirement("flatbuffers"), + requirement("tensorflow"), ], ) diff --git a/tensorflow/lite/python/schema_util.py b/tensorflow/lite/python/schema_util.py index 3f351f0caa4..e898a47318d 100644 --- a/tensorflow/lite/python/schema_util.py +++ b/tensorflow/lite/python/schema_util.py @@ -14,6 +14,7 @@ # ============================================================================== """Schema utilities to get builtin code from operator code.""" +from tensorflow.python.util import all_util def get_builtin_code_from_operator_code(opcode): @@ -37,3 +38,8 @@ def get_builtin_code_from_operator_code(opcode): return max(opcode.builtinCode, opcode.deprecatedBuiltinCode) +_allowed_symbols = [ + 'get_builtin_code_from_operator_code', +] + +all_util.remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index a84dadde5a7..0250959943d 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -8,6 +8,7 @@ py_library( deps = [ "//:tflite_micro_shim", requirement("flatbuffers"), + requirement("tensorflow"), "//tensorflow/lite/python:schema_py", "//tensorflow/lite/python:schema_util", ], @@ -19,6 +20,7 @@ py_library( deps = [ "//:tflite_micro_shim", requirement("flatbuffers"), + requirement("tensorflow"), "//tensorflow/lite/python:schema_py", ], ) @@ -48,6 +50,7 @@ py_test( deps = [ ":flatbuffer_utils", ":test_utils", + requirement("tensorflow"), ], ) @@ -57,5 +60,6 @@ py_test( deps = [ ":test_utils", ":visualize", + requirement("tensorflow"), ], ) diff --git a/tensorflow/lite/tools/flatbuffer_utils.py b/tensorflow/lite/tools/flatbuffer_utils.py index a3495e18bf3..71e1afed6b5 100644 --- a/tensorflow/lite/tools/flatbuffer_utils.py +++ b/tensorflow/lite/tools/flatbuffer_utils.py @@ -31,7 +31,7 @@ from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb from tflite_micro.tensorflow.lite.python import schema_util -import os +from tensorflow.python.platform import gfile _TFLITE_FILE_IDENTIFIER = b'TFL3' @@ -55,9 +55,9 @@ def read_model(input_tflite_file): Returns: A python object corresponding to the input tflite file. """ - if not os.path.exists(input_tflite_file): + if not gfile.Exists(input_tflite_file): raise RuntimeError('Input file not found at %r\n' % input_tflite_file) - with open(input_tflite_file, 'rb') as input_file_handle: + with gfile.GFile(input_tflite_file, 'rb') as input_file_handle: model_bytearray = bytearray(input_file_handle.read()) return read_model_from_bytearray(model_bytearray) @@ -144,7 +144,7 @@ def write_model(model_object, output_tflite_file): model_object = copy.deepcopy(model_object) byte_swap_tflite_model_obj(model_object, 'big', 'little') model_bytearray = convert_object_to_bytearray(model_object) - with open(output_tflite_file, 'wb') as output_file_handle: + with gfile.GFile(output_tflite_file, 'wb') as output_file_handle: output_file_handle.write(model_bytearray) diff --git a/tensorflow/lite/tools/flatbuffer_utils_test.py b/tensorflow/lite/tools/flatbuffer_utils_test.py index 6fab6e7e832..13074aaca5e 100644 --- a/tensorflow/lite/tools/flatbuffer_utils_test.py +++ b/tensorflow/lite/tools/flatbuffer_utils_test.py @@ -21,20 +21,20 @@ from tflite_micro.tensorflow.lite.python import schema_py_generated as schema # pylint:disable=g-direct-tensorflow-import from tflite_micro.tensorflow.lite.tools import flatbuffer_utils from tflite_micro.tensorflow.lite.tools import test_utils -import unittest -import tempfile +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test _SKIPPED_BUFFER_INDEX = 1 -class WriteReadModelTest(unittest.TestCase): +class WriteReadModelTest(test_util.TensorFlowTestCase): def testWriteReadModel(self): # 1. SETUP # Define the initial model initial_model = test_utils.build_mock_model() # Define temporary files - tmp_dir = tempfile.mkdtemp() + tmp_dir = self.get_temp_dir() model_filename = os.path.join(tmp_dir, 'model.tflite') # 2. INVOKE @@ -72,7 +72,7 @@ def testWriteReadModel(self): self.assertEqual(initial_buffer.data[i], final_buffer.data[i]) -class StripStringsTest(unittest.TestCase): +class StripStringsTest(test_util.TensorFlowTestCase): def testStripStrings(self): # 1. SETUP @@ -120,7 +120,7 @@ def testStripStrings(self): self.assertEqual(initial_buffer.data[i], final_buffer.data[i]) -class RandomizeWeightsTest(unittest.TestCase): +class RandomizeWeightsTest(test_util.TensorFlowTestCase): def testRandomizeWeights(self): # 1. SETUP @@ -204,7 +204,7 @@ def testRandomizeSomeWeights(self): self.assertEqual(initial_buffer.data[j], final_buffer.data[j]) -class XxdOutputToBytesTest(unittest.TestCase): +class XxdOutputToBytesTest(test_util.TensorFlowTestCase): def testXxdOutputToBytes(self): # 1. SETUP @@ -213,7 +213,7 @@ def testXxdOutputToBytes(self): initial_bytes = flatbuffer_utils.convert_object_to_bytearray(initial_model) # Define temporary files - tmp_dir = tempfile.mkdtemp() + tmp_dir = self.get_temp_dir() model_filename = os.path.join(tmp_dir, 'model.tflite') # 2. Write model to temporary file (will be used as input for xxd) @@ -236,7 +236,7 @@ def testXxdOutputToBytes(self): self.assertEqual(initial_bytes, final_bytes) -class CountResourceVariablesTest(unittest.TestCase): +class CountResourceVariablesTest(test_util.TensorFlowTestCase): def testCountResourceVariables(self): # 1. SETUP @@ -250,7 +250,7 @@ def testCountResourceVariables(self): flatbuffer_utils.count_resource_variables(initial_model), 1) -class GetOptionsTest(unittest.TestCase): +class GetOptionsTest(test_util.TensorFlowTestCase): op: schema.Operator op_t: schema.OperatorT @@ -290,4 +290,4 @@ def test_get_options_op_type_does_not_match(self): if __name__ == '__main__': - unittest.main() + test.main() diff --git a/tensorflow/lite/tools/visualize_test.py b/tensorflow/lite/tools/visualize_test.py index 4deb5a7ccbb..68de38cc9d7 100644 --- a/tensorflow/lite/tools/visualize_test.py +++ b/tensorflow/lite/tools/visualize_test.py @@ -18,11 +18,11 @@ from tflite_micro.tensorflow.lite.tools import test_utils from tflite_micro.tensorflow.lite.tools import visualize -import unittest -import tempfile +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test -class VisualizeTest(unittest.TestCase): +class VisualizeTest(test_util.TensorFlowTestCase): def testTensorTypeToName(self): self.assertEqual('FLOAT32', visualize.TensorTypeToName(0)) @@ -42,7 +42,7 @@ def testFlatbufferToDict(self): def testVisualize(self): model = test_utils.build_mock_flatbuffer_model() - tmp_dir = tempfile.mkdtemp() + tmp_dir = self.get_temp_dir() model_filename = os.path.join(tmp_dir, 'model.tflite') with open(model_filename, 'wb') as model_file: model_file.write(model) @@ -61,4 +61,4 @@ def testVisualize(self): if __name__ == '__main__': - unittest.main() + test.main() From 1c83043c2d350c540cae8cafb09d2f2cac65ea5d Mon Sep 17 00:00:00 2001 From: Esun Kim Date: Wed, 6 May 2026 10:13:24 -0700 Subject: [PATCH 3/5] Vendoering revise --- python/tflite_micro/BUILD | 41 +- python/tflite_micro/flatbuffer_utils.py | 537 ------------------- python/tflite_micro/flatbuffer_utils_test.py | 293 ---------- python/tflite_micro/runtime.py | 35 +- python/tflite_micro/test_utils.py | 299 ----------- python/tflite_micro/visualize_test.py | 64 --- 6 files changed, 33 insertions(+), 1236 deletions(-) delete mode 100644 python/tflite_micro/flatbuffer_utils.py delete mode 100644 python/tflite_micro/flatbuffer_utils_test.py delete mode 100644 python/tflite_micro/test_utils.py delete mode 100644 python/tflite_micro/visualize_test.py diff --git a/python/tflite_micro/BUILD b/python/tflite_micro/BUILD index 9768b4a540d..0fdf5b53428 100644 --- a/python/tflite_micro/BUILD +++ b/python/tflite_micro/BUILD @@ -71,45 +71,6 @@ pybind_extension( ], ) -py_library( - name = "flatbuffer_utils", - srcs = ["flatbuffer_utils.py"], - visibility = ["//visibility:public"], - deps = [ - "//:tflite_micro_shim", - requirement("flatbuffers"), - "//tensorflow/lite/python:schema_py", - ], -) - -py_library( - name = "test_utils", - srcs = ["test_utils.py"], - deps = [ - "//:tflite_micro_shim", - requirement("flatbuffers"), - "//tensorflow/lite/python:schema_py", - ], -) - -py_test( - name = "flatbuffer_utils_test", - srcs = ["flatbuffer_utils_test.py"], - deps = [ - ":flatbuffer_utils", - ":test_utils", - ], -) - -py_test( - name = "visualize_test", - srcs = ["visualize_test.py"], - deps = [ - ":test_utils", - "//tensorflow/lite/tools:visualize", - ], -) - py_library( name = "runtime", srcs = [ @@ -124,7 +85,7 @@ py_library( "//:tflite_micro_shim", requirement("numpy"), "//tensorflow/lite/micro/tools:generate_test_for_model", - ":flatbuffer_utils", + "//tensorflow/lite/python:schema_py", ], ) diff --git a/python/tflite_micro/flatbuffer_utils.py b/python/tflite_micro/flatbuffer_utils.py deleted file mode 100644 index d20a8e3b0ab..00000000000 --- a/python/tflite_micro/flatbuffer_utils.py +++ /dev/null @@ -1,537 +0,0 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utility functions for FlatBuffers. - -All functions that are commonly used to work with FlatBuffers. - -Refer to the tensorflow lite flatbuffer schema here: -tensorflow/lite/schema/schema.fbs -""" - -import copy -import random -import re -import struct -import sys -from typing import Optional, Type, TypeVar, Union - -import flatbuffers - -from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb -import os - - -def get_builtin_code_from_operator_code(opcode): - """Return the builtin code of the given operator code. - - The following method is introduced to resolve op builtin code shortage - problem. The new builtin operator will be assigned to the extended builtin - code field in the flatbuffer schema. Those methods helps to hide builtin code - details. - - Args: - opcode: Operator code. - - Returns: - The builtin code of the given operator code. - """ - # Access BuiltinCode() method first if available. - if hasattr(opcode, 'BuiltinCode') and callable(opcode.BuiltinCode): - return max(opcode.BuiltinCode(), opcode.DeprecatedBuiltinCode()) - - return max(opcode.builtinCode, opcode.deprecatedBuiltinCode) - -_TFLITE_FILE_IDENTIFIER = b'TFL3' - - -def convert_bytearray_to_object(model_bytearray): - """Converts a tflite model from a bytearray to an object for parsing.""" - model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0) - return schema_fb.ModelT.InitFromObj(model_object) - - -def read_model(input_tflite_file): - """Reads a tflite model as a python object. - - Args: - input_tflite_file: Full path name to the input tflite file - - Raises: - RuntimeError: If input_tflite_file path is invalid. - IOError: If input_tflite_file cannot be opened. - - Returns: - A python object corresponding to the input tflite file. - """ - if not os.path.exists(input_tflite_file): - raise RuntimeError('Input file not found at %r\n' % input_tflite_file) - with open(input_tflite_file, 'rb') as input_file_handle: - model_bytearray = bytearray(input_file_handle.read()) - return read_model_from_bytearray(model_bytearray) - - -def read_model_from_bytearray(model_bytearray): - """Reads a tflite model as a python object. - - Args: - model_bytearray: TFLite model in bytearray format. - - Returns: - A python object corresponding to the input tflite file. - """ - model = convert_bytearray_to_object(model_bytearray) - if sys.byteorder == 'big': - byte_swap_tflite_model_obj(model, 'little', 'big') - - # Offset handling for models > 2GB - for buffer in model.buffers: - if buffer.offset: - buffer.data = model_bytearray[buffer.offset : buffer.offset + buffer.size] - buffer.offset = 0 - buffer.size = 0 - for subgraph in model.subgraphs: - for op in subgraph.operators: - if op.largeCustomOptionsOffset: - op.customOptions = model_bytearray[ - op.largeCustomOptionsOffset : op.largeCustomOptionsOffset - + op.largeCustomOptionsSize - ] - op.largeCustomOptionsOffset = 0 - op.largeCustomOptionsSize = 0 - - return model - - -def read_model_with_mutable_tensors(input_tflite_file): - """Reads a tflite model as a python object with mutable tensors. - - Similar to read_model() with the addition that the returned object has - mutable tensors (read_model() returns an object with immutable tensors). - - NOTE: This API only works for TFLite generated with - _experimental_use_buffer_offset=false - - Args: - input_tflite_file: Full path name to the input tflite file - - Raises: - RuntimeError: If input_tflite_file path is invalid. - IOError: If input_tflite_file cannot be opened. - - Returns: - A mutable python object corresponding to the input tflite file. - """ - return copy.deepcopy(read_model(input_tflite_file)) - - -def convert_object_to_bytearray(model_object, extra_buffer=b''): - """Converts a tflite model from an object to a immutable bytearray.""" - # Initial size of the buffer, which will grow automatically if needed - builder = flatbuffers.Builder(1024) - model_offset = model_object.Pack(builder) - builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER) - model_bytearray = bytes(builder.Output()) - model_bytearray = model_bytearray + extra_buffer - return model_bytearray - - -def write_model(model_object, output_tflite_file): - """Writes the tflite model, a python object, into the output file. - - NOTE: This API only works for TFLite generated with - _experimental_use_buffer_offset=false - - Args: - model_object: A tflite model as a python object - output_tflite_file: Full path name to the output tflite file. - - Raises: - IOError: If output_tflite_file path is invalid or cannot be opened. - """ - if sys.byteorder == 'big': - model_object = copy.deepcopy(model_object) - byte_swap_tflite_model_obj(model_object, 'big', 'little') - model_bytearray = convert_object_to_bytearray(model_object) - with open(output_tflite_file, 'wb') as output_file_handle: - output_file_handle.write(model_bytearray) - - -def strip_strings(model): - """Strips all nonessential strings from the model to reduce model size. - - We remove the following strings: - (find strings by searching ":string" in the tensorflow lite flatbuffer schema) - 1. Model description - 2. SubGraph name - 3. Tensor names - We retain OperatorCode custom_code and Metadata name. - - Args: - model: The model from which to remove nonessential strings. - """ - - model.description = None - for subgraph in model.subgraphs: - subgraph.name = None - for tensor in subgraph.tensors: - tensor.name = None - # We clear all signature_def structure, since without names it is useless. - model.signatureDefs = None - - -def type_to_name(tensor_type): - """Converts a numerical enum to a readable tensor type.""" - for name, value in schema_fb.TensorType.__dict__.items(): - if value == tensor_type: - return name - return None - - -def randomize_weights(model, random_seed=0, buffers_to_skip=None): - """Randomize weights in a model. - - Args: - model: The model in which to randomize weights. - random_seed: The input to the random number generator (default value is 0). - buffers_to_skip: The list of buffer indices to skip. The weights in these - buffers are left unmodified. - """ - - # The input to the random seed generator. The default value is 0. - random.seed(random_seed) - - # Parse model buffers which store the model weights - buffers = model.buffers - buffer_ids = range(1, len(buffers)) # ignore index 0 as it's always None - if buffers_to_skip is not None: - buffer_ids = [idx for idx in buffer_ids if idx not in buffers_to_skip] - - buffer_types = {} - for graph in model.subgraphs: - for op in graph.operators: - if op.inputs is None: - break - for input_idx in op.inputs: - tensor = graph.tensors[input_idx] - buffer_types[tensor.buffer] = type_to_name(tensor.type) - - for i in buffer_ids: - buffer_i_data = buffers[i].data - buffer_i_size = 0 if buffer_i_data is None else buffer_i_data.size - if buffer_i_size == 0: - continue - - # Raw data buffers are of type ubyte (or uint8) whose values lie in the - # range [0, 255]. Those ubytes (or unint8s) are the underlying - # representation of each datatype. For example, a bias tensor of type - # int32 appears as a buffer 4 times it's length of type ubyte (or uint8). - # For floats, we need to generate a valid float and then pack it into - # the raw bytes in place. - buffer_type = buffer_types.get(i, 'INT8') - if buffer_type.startswith('FLOAT'): - format_code = 'e' if buffer_type == 'FLOAT16' else 'f' - for offset in range(0, buffer_i_size, struct.calcsize(format_code)): - value = random.uniform(-0.5, 0.5) # See http://b/152324470#comment2 - struct.pack_into(format_code, buffer_i_data, offset, value) - else: - for j in range(buffer_i_size): - buffer_i_data[j] = random.randint(0, 255) - - -def rename_custom_ops(model, map_custom_op_renames): - """Rename custom ops so they use the same naming style as builtin ops. - - Args: - model: The input tflite model. - map_custom_op_renames: A mapping from old to new custom op names. - """ - for op_code in model.operatorCodes: - if op_code.customCode: - op_code_str = op_code.customCode.decode('ascii') - if op_code_str in map_custom_op_renames: - op_code.customCode = map_custom_op_renames[op_code_str].encode('ascii') - - -def opcode_to_name(model, op_code): - """Converts a TFLite op_code to the human readable name. - - Args: - model: The input tflite model. - op_code: The op_code to resolve to a readable name. - - Returns: - A string containing the human readable op name, or None if not resolvable. - """ - op = model.operatorCodes[op_code] - code = max(op.builtinCode, op.deprecatedBuiltinCode) - for name, value in vars(schema_fb.BuiltinOperator).items(): - if value == code: - return name - return None - - -def xxd_output_to_bytes(input_cc_file): - """Converts xxd output C++ source file to bytes (immutable). - - Args: - input_cc_file: Full path name to th C++ source file dumped by xxd - - Raises: - RuntimeError: If input_cc_file path is invalid. - IOError: If input_cc_file cannot be opened. - - Returns: - A bytearray corresponding to the input cc file array. - """ - # Match hex values in the string with comma as separator - pattern = re.compile(r'\W*(0x[0-9a-fA-F,x ]+).*') - - model_bytearray = bytearray() - - with open(input_cc_file) as file_handle: - for line in file_handle: - values_match = pattern.match(line) - - if values_match is None: - continue - - # Match in the parentheses (hex array only) - list_text = values_match.group(1) - - # Extract hex values (text) from the line - # e.g. 0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, - values_text = filter(None, list_text.split(',')) - - # Convert to hex - values = [int(x, base=16) for x in values_text] - model_bytearray.extend(values) - - return bytes(model_bytearray) - - -def xxd_output_to_object(input_cc_file): - """Converts xxd output C++ source file to object. - - Args: - input_cc_file: Full path name to th C++ source file dumped by xxd - - Raises: - RuntimeError: If input_cc_file path is invalid. - IOError: If input_cc_file cannot be opened. - - Returns: - A python object corresponding to the input tflite file. - """ - model_bytes = xxd_output_to_bytes(input_cc_file) - return convert_bytearray_to_object(model_bytes) - - -def byte_swap_buffer_content(buffer, chunksize, from_endiness, to_endiness): - """Helper function for byte-swapping the buffers field.""" - to_swap = [ - buffer.data[i : i + chunksize] - for i in range(0, len(buffer.data), chunksize) - ] - buffer.data = b''.join([ - int.from_bytes(byteswap, from_endiness).to_bytes(chunksize, to_endiness) - for byteswap in to_swap - ]) - - -def byte_swap_string_content(buffer, from_endiness, to_endiness): - """Helper function for byte-swapping the string buffer. - - Args: - buffer: TFLite string buffer of from_endiness format. - from_endiness: The original endianness format of the string buffer. - to_endiness: The destined endianness format of the string buffer. - """ - num_of_strings = int.from_bytes(buffer.data[0:4], from_endiness) - string_content = bytearray(buffer.data[4 * (num_of_strings + 2) :]) - prefix_data = b''.join([ - int.from_bytes(buffer.data[i : i + 4], from_endiness).to_bytes( - 4, to_endiness - ) - for i in range(0, (num_of_strings + 1) * 4 + 1, 4) - ]) - buffer.data = prefix_data + string_content - - -def byte_swap_tflite_model_obj(model, from_endiness, to_endiness): - """Byte swaps the buffers field in a TFLite model. - - Args: - model: TFLite model object of from_endiness format. - from_endiness: The original endianness format of the buffers in model. - to_endiness: The destined endianness format of the buffers in model. - """ - if model is None: - return - # Get all the constant buffers, byte swapping them as per their data types - buffer_swapped = [] - types_of_16_bits = [ - schema_fb.TensorType.FLOAT16, - schema_fb.TensorType.INT16, - schema_fb.TensorType.UINT16, - ] - types_of_32_bits = [ - schema_fb.TensorType.FLOAT32, - schema_fb.TensorType.INT32, - schema_fb.TensorType.COMPLEX64, - schema_fb.TensorType.UINT32, - ] - types_of_64_bits = [ - schema_fb.TensorType.INT64, - schema_fb.TensorType.FLOAT64, - schema_fb.TensorType.COMPLEX128, - schema_fb.TensorType.UINT64, - ] - for subgraph in model.subgraphs: - for tensor in subgraph.tensors: - if ( - tensor.buffer > 0 - and tensor.buffer < len(model.buffers) - and tensor.buffer not in buffer_swapped - and model.buffers[tensor.buffer].data is not None - ): - if tensor.type == schema_fb.TensorType.STRING: - byte_swap_string_content( - model.buffers[tensor.buffer], from_endiness, to_endiness - ) - elif tensor.type in types_of_16_bits: - byte_swap_buffer_content( - model.buffers[tensor.buffer], 2, from_endiness, to_endiness - ) - elif tensor.type in types_of_32_bits: - byte_swap_buffer_content( - model.buffers[tensor.buffer], 4, from_endiness, to_endiness - ) - elif tensor.type in types_of_64_bits: - byte_swap_buffer_content( - model.buffers[tensor.buffer], 8, from_endiness, to_endiness - ) - else: - continue - buffer_swapped.append(tensor.buffer) - - -def byte_swap_tflite_buffer(tflite_model, from_endiness, to_endiness): - """Generates a new model byte array after byte swapping its buffers field. - - Args: - tflite_model: TFLite flatbuffer in a byte array. - from_endiness: The original endianness format of the buffers in - tflite_model. - to_endiness: The destined endianness format of the buffers in tflite_model. - - Returns: - TFLite flatbuffer in a byte array, after being byte swapped to to_endiness - format. - """ - if tflite_model is None: - return None - # Load TFLite Flatbuffer byte array into an object. - model = convert_bytearray_to_object(tflite_model) - - # Byte swapping the constant buffers as per their data types - byte_swap_tflite_model_obj(model, from_endiness, to_endiness) - - # Return a TFLite flatbuffer as a byte array. - return convert_object_to_bytearray(model) - - -def count_resource_variables(model): - """Calculates the number of unique resource variables in a model. - - Args: - model: the input tflite model, either as bytearray or object. - - Returns: - An integer number representing the number of unique resource variables. - """ - if not isinstance(model, schema_fb.ModelT): - model = convert_bytearray_to_object(model) - unique_shared_names = set() - for subgraph in model.subgraphs: - if subgraph.operators is None: - continue - for op in subgraph.operators: - builtin_code = get_builtin_code_from_operator_code( - model.operatorCodes[op.opcodeIndex] - ) - if builtin_code == schema_fb.BuiltinOperator.VAR_HANDLE: - unique_shared_names.add(op.builtinOptions.sharedName) - return len(unique_shared_names) - - -OptsT = TypeVar('OptsT') - - -def get_options_as( - op: Union[schema_fb.Operator, schema_fb.OperatorT], opts_type: Type[OptsT] -) -> Optional[OptsT]: - """Get the options of an operator as the specified type. - - Requested type must be an object-api type (ends in 'T'). - - Args: - op: The operator to get the options from. - opts_type: The type of the options to get. - - Returns: - The options as the specified type, or None if the options are not of the - specified type. - - Raises: - ValueError: If the specified type is not a valid options type. - """ - - err = ValueError(f'Unsupported options type: {opts_type}') - type_name: str = opts_type.__name__ - if not type_name.endswith('T'): - raise err - base_type_name = type_name.removesuffix('T') - is_opt_1_type = hasattr(schema_fb.BuiltinOptions, base_type_name) - if not is_opt_1_type and not hasattr( - schema_fb.BuiltinOptions2, base_type_name - ): - raise err - - if isinstance(op, schema_fb.Operator): - if not is_opt_1_type: - enum_val = getattr(schema_fb.BuiltinOptions2, base_type_name) - opts_creator = schema_fb.BuiltinOptions2Creator - raw_ops = op.BuiltinOptions2() - actual_enum_val = op.BuiltinOptions2Type() - else: - enum_val = getattr(schema_fb.BuiltinOptions, base_type_name) - opts_creator = schema_fb.BuiltinOptionsCreator - raw_ops = op.BuiltinOptions() - actual_enum_val = op.BuiltinOptionsType() - if raw_ops is None or actual_enum_val != enum_val: - return None - return opts_creator(enum_val, raw_ops) - - elif isinstance(op, schema_fb.OperatorT): - if is_opt_1_type: - raw_ops_t = op.builtinOptions - else: - raw_ops_t = op.builtinOptions2 - if raw_ops_t is None or not isinstance(raw_ops_t, opts_type): - return None - return raw_ops_t - - else: - return None diff --git a/python/tflite_micro/flatbuffer_utils_test.py b/python/tflite_micro/flatbuffer_utils_test.py deleted file mode 100644 index 75f5b61e598..00000000000 --- a/python/tflite_micro/flatbuffer_utils_test.py +++ /dev/null @@ -1,293 +0,0 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for flatbuffer_utils.py.""" -import copy -import os -import subprocess -import sys - -from tflite_micro.tensorflow.lite.python import schema_py_generated as schema # pylint:disable=g-direct-tensorflow-import -from tflite_micro.python.tflite_micro import flatbuffer_utils -from tflite_micro.python.tflite_micro import test_utils -import unittest -import tempfile - -_SKIPPED_BUFFER_INDEX = 1 - - -class WriteReadModelTest(unittest.TestCase): - - def testWriteReadModel(self): - # 1. SETUP - # Define the initial model - initial_model = test_utils.build_mock_model() - # Define temporary files - tmp_dir = tempfile.mkdtemp() - model_filename = os.path.join(tmp_dir, 'model.tflite') - - # 2. INVOKE - # Invoke the write_model and read_model functions - flatbuffer_utils.write_model(initial_model, model_filename) - final_model = flatbuffer_utils.read_model(model_filename) - - # 3. VALIDATE - # Validate that the initial and final models are the same - # Validate the description - self.assertEqual(initial_model.description, final_model.description) - # Validate the main subgraph's name, inputs, outputs, operators and tensors - initial_subgraph = initial_model.subgraphs[0] - final_subgraph = final_model.subgraphs[0] - self.assertEqual(initial_subgraph.name, final_subgraph.name) - for i in range(len(initial_subgraph.inputs)): - self.assertEqual(initial_subgraph.inputs[i], final_subgraph.inputs[i]) - for i in range(len(initial_subgraph.outputs)): - self.assertEqual(initial_subgraph.outputs[i], final_subgraph.outputs[i]) - for i in range(len(initial_subgraph.operators)): - self.assertEqual(initial_subgraph.operators[i].opcodeIndex, - final_subgraph.operators[i].opcodeIndex) - initial_tensors = initial_subgraph.tensors - final_tensors = final_subgraph.tensors - for i in range(len(initial_tensors)): - self.assertEqual(initial_tensors[i].name, final_tensors[i].name) - self.assertEqual(initial_tensors[i].type, final_tensors[i].type) - self.assertEqual(initial_tensors[i].buffer, final_tensors[i].buffer) - for j in range(len(initial_tensors[i].shape)): - self.assertEqual(initial_tensors[i].shape[j], final_tensors[i].shape[j]) - # Validate the first valid buffer (index 0 is always None) - initial_buffer = initial_model.buffers[1].data - final_buffer = final_model.buffers[1].data - for i in range(initial_buffer.size): - self.assertEqual(initial_buffer.data[i], final_buffer.data[i]) - - -class StripStringsTest(unittest.TestCase): - - def testStripStrings(self): - # 1. SETUP - # Define the initial model - initial_model = test_utils.build_mock_model() - final_model = copy.deepcopy(initial_model) - - # 2. INVOKE - # Invoke the strip_strings function - flatbuffer_utils.strip_strings(final_model) - - # 3. VALIDATE - # Validate that the initial and final models are the same except strings - # Validate the description - self.assertIsNotNone(initial_model.description) - self.assertIsNone(final_model.description) - self.assertIsNotNone(initial_model.signatureDefs) - self.assertIsNone(final_model.signatureDefs) - - # Validate the main subgraph's name, inputs, outputs, operators and tensors - initial_subgraph = initial_model.subgraphs[0] - final_subgraph = final_model.subgraphs[0] - self.assertIsNotNone(initial_model.subgraphs[0].name) - self.assertIsNone(final_model.subgraphs[0].name) - for i in range(len(initial_subgraph.inputs)): - self.assertEqual(initial_subgraph.inputs[i], final_subgraph.inputs[i]) - for i in range(len(initial_subgraph.outputs)): - self.assertEqual(initial_subgraph.outputs[i], final_subgraph.outputs[i]) - for i in range(len(initial_subgraph.operators)): - self.assertEqual(initial_subgraph.operators[i].opcodeIndex, - final_subgraph.operators[i].opcodeIndex) - initial_tensors = initial_subgraph.tensors - final_tensors = final_subgraph.tensors - for i in range(len(initial_tensors)): - self.assertIsNotNone(initial_tensors[i].name) - self.assertIsNone(final_tensors[i].name) - self.assertEqual(initial_tensors[i].type, final_tensors[i].type) - self.assertEqual(initial_tensors[i].buffer, final_tensors[i].buffer) - for j in range(len(initial_tensors[i].shape)): - self.assertEqual(initial_tensors[i].shape[j], final_tensors[i].shape[j]) - # Validate the first valid buffer (index 0 is always None) - initial_buffer = initial_model.buffers[1].data - final_buffer = final_model.buffers[1].data - for i in range(initial_buffer.size): - self.assertEqual(initial_buffer.data[i], final_buffer.data[i]) - - -class RandomizeWeightsTest(unittest.TestCase): - - def testRandomizeWeights(self): - # 1. SETUP - # Define the initial model - initial_model = test_utils.build_mock_model() - final_model = copy.deepcopy(initial_model) - - # 2. INVOKE - # Invoke the randomize_weights function - flatbuffer_utils.randomize_weights(final_model) - - # 3. VALIDATE - # Validate that the initial and final models are the same, except that - # the weights in the model buffer have been modified (i.e, randomized) - # Validate the description - self.assertEqual(initial_model.description, final_model.description) - # Validate the main subgraph's name, inputs, outputs, operators and tensors - initial_subgraph = initial_model.subgraphs[0] - final_subgraph = final_model.subgraphs[0] - self.assertEqual(initial_subgraph.name, final_subgraph.name) - for i in range(len(initial_subgraph.inputs)): - self.assertEqual(initial_subgraph.inputs[i], final_subgraph.inputs[i]) - for i in range(len(initial_subgraph.outputs)): - self.assertEqual(initial_subgraph.outputs[i], final_subgraph.outputs[i]) - for i in range(len(initial_subgraph.operators)): - self.assertEqual(initial_subgraph.operators[i].opcodeIndex, - final_subgraph.operators[i].opcodeIndex) - initial_tensors = initial_subgraph.tensors - final_tensors = final_subgraph.tensors - for i in range(len(initial_tensors)): - self.assertEqual(initial_tensors[i].name, final_tensors[i].name) - self.assertEqual(initial_tensors[i].type, final_tensors[i].type) - self.assertEqual(initial_tensors[i].buffer, final_tensors[i].buffer) - for j in range(len(initial_tensors[i].shape)): - self.assertEqual(initial_tensors[i].shape[j], final_tensors[i].shape[j]) - # Validate the first valid buffer (index 0 is always None) - initial_buffer = initial_model.buffers[1].data - final_buffer = final_model.buffers[1].data - for j in range(initial_buffer.size): - self.assertNotEqual(initial_buffer.data[j], final_buffer.data[j]) - - def testRandomizeSomeWeights(self): - # 1. SETUP - # Define the initial model - initial_model = test_utils.build_mock_model() - final_model = copy.deepcopy(initial_model) - - # 2. INVOKE - # Invoke the randomize_weights function, but skip the first buffer - flatbuffer_utils.randomize_weights( - final_model, buffers_to_skip=[_SKIPPED_BUFFER_INDEX]) - - # 3. VALIDATE - # Validate that the initial and final models are the same, except that - # the weights in the model buffer have been modified (i.e, randomized) - # Validate the description - self.assertEqual(initial_model.description, final_model.description) - # Validate the main subgraph's name, inputs, outputs, operators and tensors - initial_subgraph = initial_model.subgraphs[0] - final_subgraph = final_model.subgraphs[0] - self.assertEqual(initial_subgraph.name, final_subgraph.name) - for i, _ in enumerate(initial_subgraph.inputs): - self.assertEqual(initial_subgraph.inputs[i], final_subgraph.inputs[i]) - for i, _ in enumerate(initial_subgraph.outputs): - self.assertEqual(initial_subgraph.outputs[i], final_subgraph.outputs[i]) - for i, _ in enumerate(initial_subgraph.operators): - self.assertEqual(initial_subgraph.operators[i].opcodeIndex, - final_subgraph.operators[i].opcodeIndex) - initial_tensors = initial_subgraph.tensors - final_tensors = final_subgraph.tensors - for i, _ in enumerate(initial_tensors): - self.assertEqual(initial_tensors[i].name, final_tensors[i].name) - self.assertEqual(initial_tensors[i].type, final_tensors[i].type) - self.assertEqual(initial_tensors[i].buffer, final_tensors[i].buffer) - for j in range(len(initial_tensors[i].shape)): - self.assertEqual(initial_tensors[i].shape[j], final_tensors[i].shape[j]) - # Validate that the skipped buffer is unchanged. - initial_buffer = initial_model.buffers[_SKIPPED_BUFFER_INDEX].data - final_buffer = final_model.buffers[_SKIPPED_BUFFER_INDEX].data - for j in range(initial_buffer.size): - self.assertEqual(initial_buffer.data[j], final_buffer.data[j]) - - -class XxdOutputToBytesTest(unittest.TestCase): - - def testXxdOutputToBytes(self): - # 1. SETUP - # Define the initial model - initial_model = test_utils.build_mock_model() - initial_bytes = flatbuffer_utils.convert_object_to_bytearray(initial_model) - - # Define temporary files - tmp_dir = tempfile.mkdtemp() - model_filename = os.path.join(tmp_dir, 'model.tflite') - - # 2. Write model to temporary file (will be used as input for xxd) - flatbuffer_utils.write_model(initial_model, model_filename) - - # 3. DUMP WITH xxd - input_cc_file = os.path.join(tmp_dir, 'model.cc') - - command = 'xxd -i {} > {}'.format(model_filename, input_cc_file) - subprocess.call(command, shell=True) - - # 4. VALIDATE - final_bytes = flatbuffer_utils.xxd_output_to_bytes(input_cc_file) - if sys.byteorder == 'big': - final_bytes = flatbuffer_utils.byte_swap_tflite_buffer( - final_bytes, 'little', 'big' - ) - - # Validate that the initial and final bytearray are the same - self.assertEqual(initial_bytes, final_bytes) - - -class CountResourceVariablesTest(unittest.TestCase): - - def testCountResourceVariables(self): - # 1. SETUP - # Define the initial model - initial_model = test_utils.build_mock_model() - - # 2. Confirm that resource variables for mock model is 1 - # The mock model is created with two VAR HANDLE ops, but with the same - # shared name. - self.assertEqual( - flatbuffer_utils.count_resource_variables(initial_model), 1) - - -class GetOptionsTest(unittest.TestCase): - - op: schema.Operator - op_t: schema.OperatorT - - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.op = test_utils.build_operator_with_options() - cls.op_t = schema.OperatorT.InitFromObj(cls.op) - - def test_get_options(self): - ty = schema.StableHLOCompositeOptionsT - opts = flatbuffer_utils.get_options_as(self.op, ty) - self.assertIsNotNone(opts) - self.assertIsInstance(opts, ty) - self.assertEqual(opts.decompositionSubgraphIndex, 10) - - def test_get_options_obj(self): - ty = schema.StableHLOCompositeOptionsT - opts = flatbuffer_utils.get_options_as(self.op_t, ty) - self.assertIsNotNone(opts) - self.assertIsInstance(opts, ty) - self.assertEqual(opts.decompositionSubgraphIndex, 10) - - def test_get_options_not_schema_type_raises(self): - with self.assertRaises(ValueError): - flatbuffer_utils.get_options_as(self.op, int) - - def test_get_options_not_object_type_raises(self): - with self.assertRaises(ValueError): - flatbuffer_utils.get_options_as(self.op, schema.StableHLOCompositeOptions) - - def test_get_options_op_type_does_not_match(self): - ty = schema.Conv2DOptionsT - opts = flatbuffer_utils.get_options_as(self.op, ty) - self.assertIsNone(opts) - - -if __name__ == '__main__': - unittest.main() diff --git a/python/tflite_micro/runtime.py b/python/tflite_micro/runtime.py index 8ef04bdc109..fda526ee266 100644 --- a/python/tflite_micro/runtime.py +++ b/python/tflite_micro/runtime.py @@ -16,8 +16,38 @@ import enum import os -from tflite_micro.python.tflite_micro import flatbuffer_utils from tflite_micro.python.tflite_micro import _runtime +from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb + + +def convert_bytearray_to_object(model_bytearray): + """Converts a tflite model from a bytearray to an object for parsing.""" + model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0) + return schema_fb.ModelT.InitFromObj(model_object) + + +def get_builtin_code_from_operator_code(opcode): + """Return the builtin code of the given operator code.""" + if hasattr(opcode, 'BuiltinCode') and callable(opcode.BuiltinCode): + return max(opcode.BuiltinCode(), opcode.DeprecatedBuiltinCode()) + return max(opcode.builtinCode, opcode.deprecatedBuiltinCode) + + +def count_resource_variables(model): + """Calculates the number of unique resource variables in a model.""" + if not isinstance(model, schema_fb.ModelT): + model = convert_bytearray_to_object(model) + unique_shared_names = set() + for subgraph in model.subgraphs: + if subgraph.operators is None: + continue + for op in subgraph.operators: + builtin_code = get_builtin_code_from_operator_code( + model.operatorCodes[op.opcodeIndex] + ) + if builtin_code == schema_fb.BuiltinOperator.VAR_HANDLE: + unique_shared_names.add(op.builtinOptions.sharedName) + return len(unique_shared_names) class InterpreterConfig(enum.Enum): @@ -83,8 +113,7 @@ def __init__( if arena_size is None: arena_size = len(model_data) * 10 # Some models make use of resource variables ops, get the count here - num_resource_variables = flatbuffer_utils.count_resource_variables( - model_data) + num_resource_variables = count_resource_variables(model_data) print("Number of resource variables the model uses = ", num_resource_variables) diff --git a/python/tflite_micro/test_utils.py b/python/tflite_micro/test_utils.py deleted file mode 100644 index 44157143d5d..00000000000 --- a/python/tflite_micro/test_utils.py +++ /dev/null @@ -1,299 +0,0 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utility functions that support testing. - -All functions that can be commonly used by various tests. -""" - -import flatbuffers -from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb - -TFLITE_SCHEMA_VERSION = 3 - - -def build_mock_flatbuffer_model(): - """Creates a flatbuffer containing an example model.""" - builder = flatbuffers.Builder(1024) - - schema_fb.BufferStart(builder) - buffer0_offset = schema_fb.BufferEnd(builder) - - schema_fb.BufferStartDataVector(builder, 12) - builder.PrependUint8(11) - builder.PrependUint8(10) - builder.PrependUint8(9) - builder.PrependUint8(8) - builder.PrependUint8(7) - builder.PrependUint8(6) - builder.PrependUint8(5) - builder.PrependUint8(4) - builder.PrependUint8(3) - builder.PrependUint8(2) - builder.PrependUint8(1) - builder.PrependUint8(0) - buffer1_data_offset = builder.EndVector() - schema_fb.BufferStart(builder) - schema_fb.BufferAddData(builder, buffer1_data_offset) - buffer1_offset = schema_fb.BufferEnd(builder) - - schema_fb.BufferStart(builder) - buffer2_offset = schema_fb.BufferEnd(builder) - - schema_fb.ModelStartBuffersVector(builder, 3) - builder.PrependUOffsetTRelative(buffer2_offset) - builder.PrependUOffsetTRelative(buffer1_offset) - builder.PrependUOffsetTRelative(buffer0_offset) - buffers_offset = builder.EndVector() - - string0_offset = builder.CreateString('input_tensor') - schema_fb.TensorStartShapeVector(builder, 3) - builder.PrependInt32(1) - builder.PrependInt32(2) - builder.PrependInt32(5) - shape0_offset = builder.EndVector() - schema_fb.TensorStart(builder) - schema_fb.TensorAddName(builder, string0_offset) - schema_fb.TensorAddShape(builder, shape0_offset) - schema_fb.TensorAddType(builder, 0) - schema_fb.TensorAddBuffer(builder, 0) - tensor0_offset = schema_fb.TensorEnd(builder) - - schema_fb.QuantizationParametersStartMinVector(builder, 5) - builder.PrependFloat32(0.5) - builder.PrependFloat32(2.0) - builder.PrependFloat32(5.0) - builder.PrependFloat32(10.0) - builder.PrependFloat32(20.0) - quant1_min_offset = builder.EndVector() - - schema_fb.QuantizationParametersStartMaxVector(builder, 5) - builder.PrependFloat32(10.0) - builder.PrependFloat32(20.0) - builder.PrependFloat32(-50.0) - builder.PrependFloat32(1.0) - builder.PrependFloat32(2.0) - quant1_max_offset = builder.EndVector() - - schema_fb.QuantizationParametersStartScaleVector(builder, 5) - builder.PrependFloat32(3.0) - builder.PrependFloat32(4.0) - builder.PrependFloat32(5.0) - builder.PrependFloat32(6.0) - builder.PrependFloat32(7.0) - quant1_scale_offset = builder.EndVector() - - schema_fb.QuantizationParametersStartZeroPointVector(builder, 5) - builder.PrependInt64(1) - builder.PrependInt64(2) - builder.PrependInt64(3) - builder.PrependInt64(-1) - builder.PrependInt64(-2) - quant1_zero_point_offset = builder.EndVector() - - schema_fb.QuantizationParametersStart(builder) - schema_fb.QuantizationParametersAddMin(builder, quant1_min_offset) - schema_fb.QuantizationParametersAddMax(builder, quant1_max_offset) - schema_fb.QuantizationParametersAddScale(builder, quant1_scale_offset) - schema_fb.QuantizationParametersAddZeroPoint(builder, - quant1_zero_point_offset) - quantization1_offset = schema_fb.QuantizationParametersEnd(builder) - - string1_offset = builder.CreateString('constant_tensor') - schema_fb.TensorStartShapeVector(builder, 3) - builder.PrependInt32(1) - builder.PrependInt32(2) - builder.PrependInt32(5) - shape1_offset = builder.EndVector() - schema_fb.TensorStart(builder) - schema_fb.TensorAddName(builder, string1_offset) - schema_fb.TensorAddShape(builder, shape1_offset) - schema_fb.TensorAddType(builder, schema_fb.TensorType.UINT8) - schema_fb.TensorAddBuffer(builder, 1) - schema_fb.TensorAddQuantization(builder, quantization1_offset) - tensor1_offset = schema_fb.TensorEnd(builder) - - string2_offset = builder.CreateString('output_tensor') - schema_fb.TensorStartShapeVector(builder, 3) - builder.PrependInt32(1) - builder.PrependInt32(2) - builder.PrependInt32(5) - shape2_offset = builder.EndVector() - schema_fb.TensorStart(builder) - schema_fb.TensorAddName(builder, string2_offset) - schema_fb.TensorAddShape(builder, shape2_offset) - schema_fb.TensorAddType(builder, 0) - schema_fb.TensorAddBuffer(builder, 2) - tensor2_offset = schema_fb.TensorEnd(builder) - - schema_fb.SubGraphStartTensorsVector(builder, 3) - builder.PrependUOffsetTRelative(tensor2_offset) - builder.PrependUOffsetTRelative(tensor1_offset) - builder.PrependUOffsetTRelative(tensor0_offset) - tensors_offset = builder.EndVector() - - schema_fb.SubGraphStartInputsVector(builder, 1) - builder.PrependInt32(0) - inputs_offset = builder.EndVector() - - schema_fb.SubGraphStartOutputsVector(builder, 1) - builder.PrependInt32(2) - outputs_offset = builder.EndVector() - - schema_fb.OperatorCodeStart(builder) - schema_fb.OperatorCodeAddBuiltinCode(builder, schema_fb.BuiltinOperator.ADD) - schema_fb.OperatorCodeAddDeprecatedBuiltinCode(builder, - schema_fb.BuiltinOperator.ADD) - schema_fb.OperatorCodeAddVersion(builder, 1) - code0_offset = schema_fb.OperatorCodeEnd(builder) - - schema_fb.OperatorCodeStart(builder) - schema_fb.OperatorCodeAddBuiltinCode(builder, - schema_fb.BuiltinOperator.VAR_HANDLE) - schema_fb.OperatorCodeAddDeprecatedBuiltinCode( - builder, schema_fb.BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES) - schema_fb.OperatorCodeAddVersion(builder, 1) - code1_offset = schema_fb.OperatorCodeEnd(builder) - - schema_fb.ModelStartOperatorCodesVector(builder, 2) - builder.PrependUOffsetTRelative(code1_offset) - builder.PrependUOffsetTRelative(code0_offset) - codes_offset = builder.EndVector() - - schema_fb.OperatorStartInputsVector(builder, 2) - builder.PrependInt32(0) - builder.PrependInt32(1) - op_inputs_offset = builder.EndVector() - - schema_fb.OperatorStartOutputsVector(builder, 1) - builder.PrependInt32(2) - op_outputs_offset = builder.EndVector() - - schema_fb.OperatorStart(builder) - schema_fb.OperatorAddOpcodeIndex(builder, 0) - schema_fb.OperatorAddInputs(builder, op_inputs_offset) - schema_fb.OperatorAddOutputs(builder, op_outputs_offset) - op0_offset = schema_fb.OperatorEnd(builder) - - shared_name = builder.CreateString('var') - schema_fb.VarHandleOptionsStart(builder) - schema_fb.VarHandleOptionsAddSharedName(builder, shared_name) - var_handle_options_offset = schema_fb.VarHandleOptionsEnd(builder) - - schema_fb.OperatorStart(builder) - schema_fb.OperatorAddOpcodeIndex(builder, 1) - schema_fb.OperatorAddBuiltinOptionsType( - builder, schema_fb.BuiltinOptions.VarHandleOptions) - schema_fb.OperatorAddBuiltinOptions(builder, var_handle_options_offset) - op1_offset = schema_fb.OperatorEnd(builder) - - schema_fb.OperatorStart(builder) - schema_fb.OperatorAddBuiltinOptionsType( - builder, schema_fb.BuiltinOptions.VarHandleOptions) - schema_fb.OperatorAddBuiltinOptions(builder, var_handle_options_offset) - op2_offset = schema_fb.OperatorEnd(builder) - - schema_fb.SubGraphStartOperatorsVector(builder, 3) - builder.PrependUOffsetTRelative(op2_offset) - builder.PrependUOffsetTRelative(op1_offset) - builder.PrependUOffsetTRelative(op0_offset) - ops_offset = builder.EndVector() - - string3_offset = builder.CreateString('subgraph_name') - schema_fb.SubGraphStart(builder) - schema_fb.SubGraphAddName(builder, string3_offset) - schema_fb.SubGraphAddTensors(builder, tensors_offset) - schema_fb.SubGraphAddInputs(builder, inputs_offset) - schema_fb.SubGraphAddOutputs(builder, outputs_offset) - schema_fb.SubGraphAddOperators(builder, ops_offset) - subgraph_offset = schema_fb.SubGraphEnd(builder) - - schema_fb.ModelStartSubgraphsVector(builder, 1) - builder.PrependUOffsetTRelative(subgraph_offset) - subgraphs_offset = builder.EndVector() - - signature_key = builder.CreateString('my_key') - input_tensor_string = builder.CreateString('input_tensor') - output_tensor_string = builder.CreateString('output_tensor') - - # Signature Inputs - schema_fb.TensorMapStart(builder) - schema_fb.TensorMapAddName(builder, input_tensor_string) - schema_fb.TensorMapAddTensorIndex(builder, 1) - input_tensor = schema_fb.TensorMapEnd(builder) - - # Signature Outputs - schema_fb.TensorMapStart(builder) - schema_fb.TensorMapAddName(builder, output_tensor_string) - schema_fb.TensorMapAddTensorIndex(builder, 2) - output_tensor = schema_fb.TensorMapEnd(builder) - - schema_fb.SignatureDefStartInputsVector(builder, 1) - builder.PrependUOffsetTRelative(input_tensor) - signature_inputs_offset = builder.EndVector() - schema_fb.SignatureDefStartOutputsVector(builder, 1) - builder.PrependUOffsetTRelative(output_tensor) - signature_outputs_offset = builder.EndVector() - - schema_fb.SignatureDefStart(builder) - schema_fb.SignatureDefAddSignatureKey(builder, signature_key) - schema_fb.SignatureDefAddInputs(builder, signature_inputs_offset) - schema_fb.SignatureDefAddOutputs(builder, signature_outputs_offset) - signature_offset = schema_fb.SignatureDefEnd(builder) - schema_fb.ModelStartSignatureDefsVector(builder, 1) - builder.PrependUOffsetTRelative(signature_offset) - signature_defs_offset = builder.EndVector() - - string4_offset = builder.CreateString('model_description') - schema_fb.ModelStart(builder) - schema_fb.ModelAddVersion(builder, TFLITE_SCHEMA_VERSION) - schema_fb.ModelAddOperatorCodes(builder, codes_offset) - schema_fb.ModelAddSubgraphs(builder, subgraphs_offset) - schema_fb.ModelAddDescription(builder, string4_offset) - schema_fb.ModelAddBuffers(builder, buffers_offset) - schema_fb.ModelAddSignatureDefs(builder, signature_defs_offset) - model_offset = schema_fb.ModelEnd(builder) - builder.Finish(model_offset) - model = builder.Output() - - return model - - -def build_operator_with_options() -> schema_fb.Operator: - """Builds an operator with the given options.""" - builder = flatbuffers.Builder(1024) - schema_fb.StableHLOCompositeOptionsStart(builder) - schema_fb.StableHLOCompositeOptionsAddDecompositionSubgraphIndex(builder, 10) - opts = schema_fb.StableHLOCompositeOptionsEnd(builder) - schema_fb.OperatorStart(builder) - schema_fb.OperatorAddBuiltinOptions2(builder, opts) - schema_fb.OperatorAddBuiltinOptions2Type( - builder, schema_fb.BuiltinOptions2.StableHLOCompositeOptions - ) - op_offset = schema_fb.OperatorEnd(builder) - builder.Finish(op_offset) - return schema_fb.Operator.GetRootAs(builder.Output()) - - -def load_model_from_flatbuffer(flatbuffer_model): - """Loads a model as a python object from a flatbuffer model.""" - model = schema_fb.Model.GetRootAsModel(flatbuffer_model, 0) - model = schema_fb.ModelT.InitFromObj(model) - return model - - -def build_mock_model(): - """Creates an object containing an example model.""" - model = build_mock_flatbuffer_model() - return load_model_from_flatbuffer(model) diff --git a/python/tflite_micro/visualize_test.py b/python/tflite_micro/visualize_test.py deleted file mode 100644 index 5c2fa31cff5..00000000000 --- a/python/tflite_micro/visualize_test.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""TensorFlow Lite Python Interface: Sanity check.""" -import os -import re - -from tflite_micro.python.tflite_micro import test_utils -from tflite_micro.tensorflow.lite.tools import visualize -import unittest -import tempfile - - -class VisualizeTest(unittest.TestCase): - - def testTensorTypeToName(self): - self.assertEqual('FLOAT32', visualize.TensorTypeToName(0)) - - def testBuiltinCodeToName(self): - self.assertEqual('HASHTABLE_LOOKUP', visualize.BuiltinCodeToName(10)) - - def testFlatbufferToDict(self): - model = test_utils.build_mock_flatbuffer_model() - model_dict = visualize.CreateDictFromFlatbuffer(model) - self.assertEqual(test_utils.TFLITE_SCHEMA_VERSION, model_dict['version']) - self.assertEqual(1, len(model_dict['subgraphs'])) - self.assertEqual(2, len(model_dict['operator_codes'])) - self.assertEqual(3, len(model_dict['buffers'])) - self.assertEqual(3, len(model_dict['subgraphs'][0]['tensors'])) - self.assertEqual(0, model_dict['subgraphs'][0]['tensors'][0]['buffer']) - - def testVisualize(self): - model = test_utils.build_mock_flatbuffer_model() - tmp_dir = tempfile.mkdtemp() - model_filename = os.path.join(tmp_dir, 'model.tflite') - with open(model_filename, 'wb') as model_file: - model_file.write(model) - - html_text = visualize.create_html(model_filename) - - # It's hard to test debug output without doing a full HTML parse, - # but at least sanity check that expected identifiers are present. - self.assertRegex( - html_text, re.compile(r'%s' % model_filename, re.MULTILINE | re.DOTALL)) - self.assertRegex(html_text, - re.compile(r'input_tensor', re.MULTILINE | re.DOTALL)) - self.assertRegex(html_text, - re.compile(r'constant_tensor', re.MULTILINE | re.DOTALL)) - self.assertRegex(html_text, re.compile(r'ADD', re.MULTILINE | re.DOTALL)) - - -if __name__ == '__main__': - unittest.main() From a202d365600f69fe480add9b3bd7efe47987deea Mon Sep 17 00:00:00 2001 From: Esun Kim Date: Wed, 6 May 2026 10:23:57 -0700 Subject: [PATCH 4/5] Removed unnecessary tf pythons --- ci/tflite_files.txt | 3 --- tensorflow/lite/tools/BUILD | 30 ------------------------------ 2 files changed, 33 deletions(-) diff --git a/ci/tflite_files.txt b/ci/tflite_files.txt index a95475ec9e0..351ece5cca4 100644 --- a/ci/tflite_files.txt +++ b/ci/tflite_files.txt @@ -115,9 +115,6 @@ tensorflow/lite/portable_type_to_tflitetype.h tensorflow/lite/python/schema_util.py tensorflow/lite/schema/schema_utils.h tensorflow/lite/tools/flatbuffer_utils.py -tensorflow/lite/tools/flatbuffer_utils_test.py tensorflow/lite/tools/randomize_weights.py tensorflow/lite/tools/strip_strings.py -tensorflow/lite/tools/test_utils.py tensorflow/lite/tools/visualize.py -tensorflow/lite/tools/visualize_test.py diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index 0250959943d..db8cbae1a0d 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -14,16 +14,6 @@ py_library( ], ) -py_library( - name = "test_utils", - srcs = ["test_utils.py"], - deps = [ - "//:tflite_micro_shim", - requirement("flatbuffers"), - requirement("tensorflow"), - "//tensorflow/lite/python:schema_py", - ], -) py_binary( name = "strip_strings", @@ -43,23 +33,3 @@ py_binary( requirement("numpy"), ], ) - -py_test( - name = "flatbuffer_utils_test", - srcs = ["flatbuffer_utils_test.py"], - deps = [ - ":flatbuffer_utils", - ":test_utils", - requirement("tensorflow"), - ], -) - -py_test( - name = "visualize_test", - srcs = ["visualize_test.py"], - deps = [ - ":test_utils", - ":visualize", - requirement("tensorflow"), - ], -) From 296286fd7d4bb7bc5c6d9b6445cced4ed4d22fb2 Mon Sep 17 00:00:00 2001 From: Esun Kim Date: Wed, 6 May 2026 10:31:05 -0700 Subject: [PATCH 5/5] Reformat --- python/tflite_micro/runtime.py | 3 +-- tensorflow/lite/micro/tools/layer_by_layer_debugger.py | 7 +++---- tensorflow/lite/tools/BUILD | 1 - 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/python/tflite_micro/runtime.py b/python/tflite_micro/runtime.py index fda526ee266..d895f8c4993 100644 --- a/python/tflite_micro/runtime.py +++ b/python/tflite_micro/runtime.py @@ -43,8 +43,7 @@ def count_resource_variables(model): continue for op in subgraph.operators: builtin_code = get_builtin_code_from_operator_code( - model.operatorCodes[op.opcodeIndex] - ) + model.operatorCodes[op.opcodeIndex]) if builtin_code == schema_fb.BuiltinOperator.VAR_HANDLE: unique_shared_names.add(op.builtinOptions.sharedName) return len(unique_shared_names) diff --git a/tensorflow/lite/micro/tools/layer_by_layer_debugger.py b/tensorflow/lite/micro/tools/layer_by_layer_debugger.py index fc601f253bd..f7ba6851c5b 100644 --- a/tensorflow/lite/micro/tools/layer_by_layer_debugger.py +++ b/tensorflow/lite/micro/tools/layer_by_layer_debugger.py @@ -21,6 +21,7 @@ from absl import flags from absl import logging import numpy as np + OpResolverType = None try: import ai_edge_litert.interpreter as tflite_interp @@ -41,8 +42,7 @@ OpResolverType = tflite_interp.experimental.OpResolverType except ImportError: raise ImportError( - "Could not import ai_edge_litert, tflite_runtime, or tensorflow." - ) + "Could not import ai_edge_litert, tflite_runtime, or tensorflow.") from tflite_micro.tensorflow.lite.tools import flatbuffer_utils from tflite_micro.python.tflite_micro import runtime @@ -222,8 +222,7 @@ def main(_) -> None: kwargs["experimental_op_resolver_type"] = OpResolverType.BUILTIN_REF else: logging.warning( - "Could not find OpResolverType. Reference kernels might not be used." - ) + "Could not find OpResolverType. Reference kernels might not be used.") tflite_interpreter = tflite_interp.Interpreter(**kwargs) diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index db8cbae1a0d..6316a167e3b 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -14,7 +14,6 @@ py_library( ], ) - py_binary( name = "strip_strings", srcs = ["strip_strings.py"],