aboutsummaryrefslogtreecommitdiff
path: root/code/sunlab/sunflow/models/discriminator.py
blob: 38bed5655795d8f82b7e5d25127d7abeb7659766 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
class Discriminator:
    """# Discriminator Model

    Constructs a discriminator model with a certain depth of intermediate
    layers of fixed size. The model maps a `latent_size` vector to a single
    sigmoid-activated output. Model files live under
    `<model_base_directory>/portable/`."""

    def __init__(self, model_base_directory):
        """# Discriminator Model Initialization

        - model_base_directory: The base folder directory where the model will
        be saved/ loaded"""
        self.model_base_directory = model_base_directory

    def _portable_path(self, filename):
        """# Path of `filename` inside the model's `portable` sub-directory"""
        from os.path import join

        return join(self.model_base_directory, "portable", filename)

    def init(self):
        """# Initialize a new Discriminator

        Expects a model parameters file to already exist in the initialization
        base directory when initializing the model (see `load_parameters`).

        Builds a Keras `Sequential` model with `depth` hidden `Dense` layers of
        width `layer_size`. Each hidden layer is followed by a LeakyReLU (or
        ReLU) and, when `dropout > 0`, a `Dropout` layer. A final 1-unit
        sigmoid `Dense` layer produces the output.

        Returns `self` so construction can be chained.

        Raises `ValueError` if the loaded depth is negative."""
        self.load_parameters()
        # Explicit raise instead of `assert`: asserts are stripped under -O.
        # Validating first also avoids importing TensorFlow for a bad config.
        if self.depth < 0:
            raise ValueError("Depth must be non-negative")

        from tensorflow import keras
        from tensorflow.keras import layers

        input_kwargs = {"input_shape": (self.latent_size,)}
        self.model = keras.models.Sequential()
        for index in range(self.depth):
            # Only the first layer of the Sequential model declares its input.
            self.model.add(
                layers.Dense(
                    self.layer_size,
                    activation=None,
                    name=f"discriminator_dense_{index + 1}",
                    **(input_kwargs if index == 0 else {}),
                )
            )
            if self.use_leaky_relu:
                self.model.add(layers.LeakyReLU())
            else:
                self.model.add(layers.ReLU())
            if self.dropout > 0.0:
                self.model.add(layers.Dropout(self.dropout))
        # With depth == 0 the output layer is the first layer, so it carries
        # the input shape. It uses sigmoid like every deeper configuration so
        # the output is always in (0, 1).
        self.model.add(
            layers.Dense(
                1,
                activation="sigmoid",
                name="discriminator_output_vector",
                **(input_kwargs if self.depth == 0 else {}),
            )
        )
        self.model._name = "Discriminator"
        return self

    def load(self):
        """# Load an existing Discriminator

        Restores `self.model` from `<base>/portable/discriminator.keras`
        (uncompiled). Returns `self`, or `None` when that file (or the
        `portable` directory itself) does not exist."""
        from os.path import exists

        path = self._portable_path("discriminator.keras")
        if not exists(path):
            return None
        import tensorflow as tf

        self.model = tf.keras.models.load_model(path, compile=False)
        self.model._name = "Discriminator"
        return self

    def save(self, overwrite=False):
        """# Save the current Discriminator

        - Overwrite: overwrite any existing discriminator that has been
        saved

        Returns `True` if the model was written, `False` if a saved
        discriminator already exists and `overwrite` is false."""
        from os.path import exists

        path = self._portable_path("discriminator.keras")
        if not overwrite and exists(path):
            return False
        self.model.save(path)
        return True

    def load_parameters(self):
        """# Load Discriminator Model Parameters from File
        Reads `<base>/portable/model_parameters.pkl`, which needs to have the
        following parameters defined:
         - data_size: int
         - adversary_layer_size: int
         - latent_size: int
         - autoencoder_depth: int
         - dropout: float (set to 0. if you don't want a dropout layer)
         - use_leaky_relu: boolean

        Raises `KeyError` if any of these keys is missing."""
        from pickle import load

        # SECURITY NOTE: unpickling can execute arbitrary code; only load
        # parameter files from trusted sources.
        with open(self._portable_path("model_parameters.pkl"), "rb") as phandle:
            parameters = load(phandle)
        self.data_size = parameters["data_size"]
        self.layer_size = parameters["adversary_layer_size"]
        self.latent_size = parameters["latent_size"]
        self.depth = parameters["autoencoder_depth"]
        self.dropout = parameters["dropout"]
        self.use_leaky_relu = parameters["use_leaky_relu"]

    def summary(self):
        """# Returns the summary of the Discriminator model

        Delegates to the underlying Keras model's `summary()`."""
        return self.model.summary()

    def __call__(self, *args, **kwargs):
        """# Callable

        When calling the discriminator class, return the model's output"""
        return self.model(*args, **kwargs)