{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dfb6ada-ed98-48c9-8610-cadb7493e138",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from sunlab.environment.base.cpu import *\n",
    "from sunlab.environment.base.extras import *\n",
    "from sunlab.globals import FILES\n",
    "from sunlab.sunflow import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4136a260-bb40-47ec-8aad-d2a6ac31f1f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "model, dataset = load_aae_and_dataset(FILES['TRAINING_DATASET'], FILES['PRETRAINED_MODEL_DIR'], MaxAbsScaler)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00694ce2",
   "metadata": {},
   "source": [
    "# Save for PyTorch!\n",
    "\n",
    "Export the trainable variables of the encoder, decoder and discriminator as NumPy arrays, bundled into a single `.npy` file. An existing export is never overwritten unless `OVERWRITE_WEIGHTS` is set to `True`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b8d2f0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_LAYERS = 3  # number of (weights, bias) pairs exported per sub-model\n",
    "\n",
    "\n",
    "def named_trainable_variables(submodel, n_layers=N_LAYERS):\n",
    "    \"\"\"Return the first ``2 * n_layers`` trainable variables of ``submodel`` as named NumPy arrays.\n",
    "\n",
    "    Keys are ``LAYER_<i>_WEIGHTS`` / ``LAYER_<i>_BIAS`` for ``i = 1..n_layers``, assigned in the\n",
    "    order the variables appear in ``submodel.trainable_variables``.\n",
    "    NOTE(review): the names assume the variables are ordered weights-then-bias per layer -- confirm.\n",
    "    \"\"\"\n",
    "    names = []\n",
    "    for layer in range(1, n_layers + 1):\n",
    "        names.extend((f\"LAYER_{layer}_WEIGHTS\", f\"LAYER_{layer}_BIAS\"))\n",
    "    variables = submodel.trainable_variables\n",
    "    # Index explicitly instead of zipping: a sub-model with fewer variables than names\n",
    "    # then fails loudly rather than being silently truncated.\n",
    "    return {name: variables[idx].numpy() for idx, name in enumerate(names)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23ae3a97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `np` and `DIR_ROOT` are not defined in this notebook; presumably they come from the\n",
    "# star imports in the first cell -- verify.\n",
    "WEIGHTS_PATH = DIR_ROOT + \"models/current_model/portable/trainable_variables.npy\"\n",
    "OVERWRITE_WEIGHTS = False  # set to True to deliberately replace an existing export\n",
    "\n",
    "ENCODER_DICT = named_trainable_variables(model.encoder.model)\n",
    "DECODER_DICT = named_trainable_variables(model.decoder.model)\n",
    "DISCRIMINATOR_DICT = named_trainable_variables(model.discriminator.model)\n",
    "\n",
    "AAE_DICT = {\n",
    "    \"ENCODER\": ENCODER_DICT,\n",
    "    \"DECODER\": DECODER_DICT,\n",
    "    \"DISCRIMINATOR\": DISCRIMINATOR_DICT,\n",
    "}\n",
    "\n",
    "if os.path.exists(WEIGHTS_PATH) and not OVERWRITE_WEIGHTS:\n",
    "    # Replaces the former unconditional `raise`: the existing file is still protected,\n",
    "    # but the notebook now runs top to bottom on a fresh kernel.\n",
    "    print(f\"Already Saved the model Weights! Keeping {WEIGHTS_PATH}\")\n",
    "else:\n",
    "    # np.save stores a dict as a pickled 0-d object array; read it back with\n",
    "    # np.load(WEIGHTS_PATH, allow_pickle=True).item()\n",
    "    np.save(WEIGHTS_PATH, AAE_DICT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3332ec53",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tfnb",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}