{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Comparison of LSTM with different transformations\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from ai4water.models import LSTM\nfrom ai4water.utils.utils import get_version_info\nfrom ai4water.experiments import TransformationExperiments\nfrom ai4water.hyperopt import Categorical, Integer\nfrom ai4water.utils.utils import dateandtime_now\n\nfrom ai4water.datasets import busan_beach\n\nfor k,v in get_version_info().items():\n    print(f\"{k} version: {v}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "lookback = 14\n\ndata = busan_beach()\ninput_features = data.columns.tolist()[0:-1]\noutput_features = data.columns.tolist()[-1:]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class MyTransformationExperiments(TransformationExperiments):\n\n    def update_paras(self, **kwargs):\n        _layers = LSTM(units=kwargs['units'],\n                       input_shape=(lookback, len(input_features)),\n                       activation=kwargs['activation'])\n\n        y_transformation = kwargs['y_transformation']\n        if y_transformation == \"none\":\n            y_transformation = None\n\n        return {\n            'model': _layers,\n            'batch_size': int(kwargs['batch_size']),\n            'lr': float(kwargs['lr']),\n            'y_transformation': y_transformation\n        }"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "cases = {\n    'model_None': {'y_transformation': 'none'},\n    'model_minmax': {'y_transformation': 'minmax'},\n    'model_zscore': {'y_transformation': 'zscore'},\n    'model_robust': {'y_transformation': 'robust'},\n    'model_quantile': {'y_transformation': 'quantile'},\n    'model_log': {'y_transformation': {'method':'log', 'treat_negatives': True, 'replace_zeros': True}},\n    \"model_pareto\": {\"y_transformation\": \"pareto\"},\n    \"model_vast\": {\"y_transformation\": \"vast\"},\n    \"model_mmad\": {\"y_transformation\": \"mmad\"}\n         }"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "search_space = [\n    Integer(low=10, high=30, name='units', num_samples=10),\n    Categorical(categories=['relu', 'elu', 'tanh', \"linear\"], name='activation'),\n    Categorical(categories=[4, 8, 12, 16, 24, 32], name='batch_size'),\n    Categorical(categories=[0.05, 0.02, 0.009, 0.007, 0.005,\n                            0.003, 0.001, 0.0009, 0.0007, 0.0005, 0.0003,\n                            0.0001, 0.00009, 0.00007, 0.00005], name='lr'),\n]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x0 = [16, \"relu\", 32, 0.0001]\n\nexperiment = MyTransformationExperiments(\n    cases=cases,\n    input_features=input_features,\n    output_features = output_features,\n    param_space=search_space,\n    x0=x0,\n    verbosity=0,\n    epochs=100,\n    exp_name = f\"ecoli_lstm_y_exp_{dateandtime_now()}\",\n    ts_args={\"lookback\": lookback},\n    save=False\n)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "experiment.fit(data = data,  run_type='dry_run')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "experiment.compare_errors('rmse', data=data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "experiment.compare_errors('r2', data=data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "experiment.compare_errors('nrmse', data=data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "experiment.taylor_plot(data=data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "experiment.compare_edf_plots(data=data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "experiment.compare_regression_plots(data=data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "experiment.compare_residual_plots(data=data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "experiment.loss_comparison()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}