
TensorFlow

edgemark.models.platforms.TensorFlow.model_generator

main

main(cfg_path=config_file_path, **kwargs)

Generate, train, evaluate, and save models based on the given configuration file.

Parameters:

Name Type Description Default
cfg_path str

The path to the configuration file containing the model generation parameters. The configuration file that this path points to should contain the following keys:

- model_type (str): A placeholder for the model type. This will be populated by the target model configuration.
- time_tag (str): A placeholder for the time tag. This will be populated by the current time.
- target_models_dir (str): Path to the directory containing the target model configurations.
- datasets_dir (str): Path to the directory containing the datasets.
- linkers_dir (str): Path to the directory where the generated models list will be saved.
- model_path (str): Path to the model file.
- model_save_dir (str): Path to the directory where the generated model will be saved.
- data_save_dir (str): Path to the directory where the representative and equality check data will be saved.
- TFLM_info_save_path (str): Path to the file where the TFLM info will be saved.
- wandb_online (bool): Flag to enable or disable the W&B online mode.
- wandb_project_name (str): Name of the W&B project.
- train_models (bool): Flag to enable or disable model training.
- evaluate_models (bool): Flag to enable or disable model evaluation.
- measure_execution_time (bool): Flag to enable or disable the measurement of execution time.
- epochs (int): Number of epochs for training the model. If specified, it overrides the number of epochs in the model configuration.
- n_representative_data (int): Number of samples to be saved for TFLite conversion.
- n_eqcheck_data (int): Number of samples to be saved for the equivalence check of the model on PC and MCU.

config_file_path
**kwargs dict

Keyword arguments that override the corresponding entries in the configuration file.

{}

Returns:

Type Description
list

A list of dictionaries containing the following keys for each target model:

- name (str): Name of the target model configuration file.
- result (str): Result of the model generation. It can be either "success" or "failed".
- error (str): Error message in case of failure.
- traceback (str): Traceback in case of failure.
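A minimal sketch of consuming the returned list. The entries below are illustrative values, not the output of a real run:

```python
# Illustrative return value; a real run yields one entry per target model file.
results = [
    {"name": "mlp_small", "result": "success"},
    {"name": "cnn_large", "result": "failed", "error": "ValueError",
     "traceback": "Traceback (most recent call last): ..."},
]

# Collect the names of the configurations that failed to generate.
failed = [r["name"] for r in results if r["result"] == "failed"]
print(failed)  # ['cnn_large']
```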

Source code in edgemark/models/platforms/TensorFlow/model_generator.py
def main(cfg_path=config_file_path, **kwargs):
    """
    Generate, train, evaluate, and save models based on the given configuration file.

    Args:
        cfg_path (str): The path to the configuration file containing the model generation parameters.
            The configuration file that this path points to should contain the following keys:
                - model_type (str): A placeholder for the model type. This will be populated by the target model configuration.
                - time_tag (str): A placeholder for the time tag. This will be populated by the current time.
                - target_models_dir (str): Path to the directory containing the target model configurations.
                - datasets_dir (str): Path to the directory containing the datasets.
                - linkers_dir (str): Path to the directory where the generated models list will be saved.
                - model_path (str): Path to the model file.
                - model_save_dir (str): Path to the directory where the generated model will be saved.
                - data_save_dir (str): Path to the directory where the representative and equality check data will be saved.
                - TFLM_info_save_path (str): Path to the file where the TFLM info will be saved.
                - wandb_online (bool): Flag to enable or disable the W&B online mode.
                - wandb_project_name (str): Name of the W&B project.
                - train_models (bool): Flag to enable or disable model training.
                - evaluate_models (bool): Flag to enable or disable model evaluation.
                - measure_execution_time (bool): Flag to enable or disable the measurement of execution time.
                - epochs (int): Number of epochs for training the model. If specified, it will override the number of epochs in the model configuration.
                - n_representative_data (int): Number of samples to be saved for TFLite conversion.
                - n_eqcheck_data (int): Number of samples to be saved for equivalence check of the model on PC and MCU.
        **kwargs (dict): Keyword arguments that override the corresponding entries in the configuration file.

    Returns:
        list: A list of dictionaries containing the following keys for each target model:
            - name (str): Name of the target model configuration file.
            - result (str): Result of the model generation. It can be either "success" or "failed".
            - error (str): Error message in case of failure.
            - traceback (str): Traceback in case of failure.
    """
    cfg = OmegaConf.load(cfg_path)
    cfg.update(OmegaConf.create(kwargs))

    if not cfg.wandb_online:
        os.environ['WANDB_MODE'] = 'offline'
    models_list = []

    target_files = find_target_files(cfg.target_models_dir)

    output = [{"name": os.path.splitext(target_file)[0]} for target_file in target_files]

    for i, target_file in enumerate(target_files):
        try:
            model_cfg_path = os.path.join(cfg.target_models_dir, target_file)
            model_cfg = OmegaConf.load(model_cfg_path)
            cfg.model_type = model_cfg.model_type
            cfg.time_tag = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            if "epochs" in cfg:
                model_cfg.epochs = cfg.epochs
            if "dataset" in model_cfg:
                model_cfg.dataset.path = os.path.join(cfg.datasets_dir, model_cfg.dataset.name, "data.py")

            wandb_name = datetime.datetime.strptime(cfg.time_tag, "%Y-%m-%d_%H-%M-%S").strftime("%Y-%m-%d %H:%M:%S")
            wandb_dir = get_abs_path(os.path.join(cfg.model_save_dir, 'tf'))
            os.makedirs(wandb_dir, exist_ok=True)
            wandb.init(project=cfg.wandb_project_name, group=model_cfg.model_type, tags=[model_cfg.model_type, model_cfg.dataset.name, os.path.splitext(target_file)[0]], name=wandb_name, dir=wandb_dir)

            spec = importlib.util.spec_from_file_location("imported_module", cfg.model_path)
            imported_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(imported_module)

            title = "Creating the {} model described in {} ({}/{})".format(model_cfg.model_type, target_file, i+1, len(target_files))
            print("\n")
            print("="*80)
            print("-"*((80-len(title)-2)//2), end=" ")
            print(title, end=" ")
            print("-"*((80-len(title)-2)//2))
            print("="*80)

            supervisor = imported_module.ModelSupervisor(OmegaConf.to_container(model_cfg, resolve=True))

            print("Saving representative data to the directory: {} ...".format(cfg.data_save_dir), end=" ", flush=True)
            supervisor.save_representative_data(cfg.n_representative_data, cfg.data_save_dir)
            print("Done\n")

            try:
                supervisor.compile_model(fine_tuning=False)
            except Exception as e:
                print("Error in compiling the model: {}".format(e))
                print("Continuing without compilation")

            # print("Model summary:")
            # supervisor.model.summary()
            # print("")
            total_params, trainable_params, non_trainable_params = supervisor.get_params_count()
            MACs = supervisor.get_FLOPs() // 2

            if cfg.train_models:
                print("Training the model ...")
                tensorboard_log_dir = os.path.join(cfg.model_save_dir, 'tf/logs')
                best_weights_dir = os.path.join(cfg.model_save_dir, 'tf/weights/weights_best')
                supervisor.train_model(fine_tuning=False, tensorboard_log_dir=tensorboard_log_dir, best_weights_dir=best_weights_dir, use_wandb=True)
                print("")

            evaluation_result = None
            if cfg.evaluate_models:
                try:
                    evaluation_result = supervisor.evaluate_model()
                    for metric, value in evaluation_result.items():
                        print(metric, ":", value)
                    print("")
                except Exception as e:
                    print("Error in evaluating the model: {}".format(e))
                    print("Continuing without evaluation")

            print("Saving model and weights to the directory: {} ...".format(cfg.model_save_dir), end=" ", flush=True)
            supervisor.save_model(os.path.join(cfg.model_save_dir, "tf/model"))
            supervisor.save_weights(os.path.join(cfg.model_save_dir, 'tf/weights/weights_last'))
            print("Done\n")
            supervisor.log_model_to_wandb(os.path.join(cfg.model_save_dir, "tf/model"), os.path.splitext(target_file)[0].replace("/", "_"))

            print("Saving equality check data to the directory: {} ...".format(cfg.data_save_dir), end=" ", flush=True)
            supervisor.save_eqcheck_data(cfg.n_eqcheck_data, cfg.data_save_dir)
            print("Done\n")

            if cfg.measure_execution_time:
                print("Measuring execution time ...")
                execution_time = supervisor.measure_execution_time()
                print("Average run time: {} ms\n".format(execution_time))

            model_info = {"Description": ""}
            model_info["setting_file"] = target_file
            model_info["model_type"] = model_cfg.model_type
            model_info["trained"] = cfg.train_models
            model_info.update(supervisor.get_model_info())
            model_info["total_params"] = total_params
            model_info["trainable_params"] = trainable_params
            model_info["non_trainable_params"] = non_trainable_params
            model_info["MACs"] = MACs
            if evaluation_result is not None:
                for metric, value in evaluation_result.items():
                    if not isinstance(metric, str):
                        metric = str(metric)
                    model_info[metric] = value
            if cfg.measure_execution_time:
                model_info["execution_time"] = execution_time
            model_info["wandb_name"] = wandb_name

            print("Saving the model info in the directory: {} ...".format(cfg.model_save_dir), end=" ", flush=True)
            supervisor.save_model_info(model_info, cfg.model_save_dir)
            print("Done\n")
            wandb.config.update(model_info)

            try:
                print("Saving the TFLM info in: {} ...".format(cfg.TFLM_info_save_path), end=" ", flush=True)
                supervisor.save_TFLM_info(cfg.TFLM_info_save_path)
                print("Done\n")
            except Exception as e:
                print("Error in saving the TFLM info: {}".format(e))
                print("TFLM info will not be saved. Please fix this issue if you want to use the TFLM converter later.")

            models_list.append(cfg.model_save_dir)

            wandb.finish()

            output[i]["result"] = "success"

        except Exception as e:
            output[i]["result"] = "failed"
            output[i]["error"] = type(e).__name__
            output[i]["traceback"] = traceback.format_exc()
            print("Error in generating the model:")
            print(traceback.format_exc())

    print("Saving the generated models list in the directory: {} ...".format(cfg.linkers_dir), end=" ", flush=True)
    save_models_list(models_list, cfg.linkers_dir)
    print("Done\n")

    return output
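The MACs figure reported by main comes from halving the profiler's FLOP count (`MACs = supervisor.get_FLOPs() // 2`). A tiny worked check of that convention for a fully connected layer; the 784 and 128 sizes are assumed here for illustration only:

```python
def dense_flops(in_features, out_features):
    # One multiply plus one add per weight, per sample: 2 * in * out FLOPs.
    return 2 * in_features * out_features

flops = dense_flops(784, 128)
macs = flops // 2            # same halving convention as in main()
print(flops, macs)           # 200704 100352
```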

save_models_list

save_models_list(models_list, save_dir)

Saves the list of generated models.

Parameters:

Name Type Description Default
models_list list

The list of generated models.

required
save_dir str

The directory where the models list should be saved.

required
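The saved file is a plain YAML list of the generated model directories, written as `tf_generated_models_list.yaml` inside save_dir. An illustrative example (the paths below are hypothetical):

```yaml
# Illustrative contents of tf_generated_models_list.yaml
- models/outputs/mlp_small/2024-01-01_12-00-00
- models/outputs/cnn_large/2024-01-01_12-05-00
```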
Source code in edgemark/models/platforms/TensorFlow/model_generator.py
def save_models_list(models_list, save_dir):
    """
    Saves the list of generated models.

    Args:
        models_list (list): The list of generated models.
        save_dir (str): The directory where the models list should be saved.
    """
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, "tf_generated_models_list.yaml"), 'w') as f:
        yaml.dump(models_list, f, indent=4, sort_keys=False)

edgemark.models.platforms.TensorFlow.model_template

This module contains a template class that other models should inherit from.

ModelSupervisorTemplate

This class is a template for TensorFlow models. To create a new model, inherit from this class and implement its abstract methods.
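A minimal subclass sketch. To keep the example self-contained and runnable without TensorFlow, a stand-in base class is defined here with only the members the sketch uses; `TinySupervisor` and its configuration keys are hypothetical, and a real subclass would assign a `tf.keras.Model` and a dataset supervisor:

```python
# Stand-in for ModelSupervisorTemplate, so this sketch runs without TensorFlow.
class ModelSupervisorTemplate:
    def __init__(self, cfg=None):
        self.model = None
        self.dataset = None

    def set_configs(self, cfg):
        raise NotImplementedError

    def get_model_info(self):
        raise NotImplementedError


class TinySupervisor(ModelSupervisorTemplate):
    """Hypothetical subclass showing the minimum surface to implement."""

    def __init__(self, cfg=None):
        super().__init__(cfg)
        self.cfg = {}
        self.set_configs(cfg or {})
        # In a real subclass these would be a tf.keras.Model and a dataset:
        self.model = object()
        self.dataset = object()

    def set_configs(self, cfg):
        self.cfg = dict(cfg)

    def get_model_info(self):
        return {"model_type": self.cfg.get("model_type", "unknown")}


s = TinySupervisor({"model_type": "mlp"})
print(s.get_model_info())  # {'model_type': 'mlp'}
```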

Source code in edgemark/models/platforms/TensorFlow/model_template.py
class ModelSupervisorTemplate:
    """
    This class is a template for TensorFlow models. In order to create a new model,
    you should inherit from this class and implement its abstract functions.
    """

    def __init__(self, cfg=None):
        """
        Initializes the class.
        The following attributes should be set in the __init__ function:
            self.model (tf.keras.Model): The model.
            self.dataset (DatasetSupervisorTemplate): The dataset.

        Args:
            cfg (dict): The configurations of the model. Defaults to None.
        """
        self.model = None
        self.dataset = None


    def set_configs(self, cfg):
        """
        Sets the configs from the given dictionary.

        Note: The changed configs won't affect the data, model, or any other loaded attributes.
        In case you want to change them, you should call the corresponding functions.

        Args:
            cfg (dict): The configurations.
        """
        raise NotImplementedError


    # Optional function: Whether you implement this function or not depends on your application.
    def compile_model(self, fine_tuning=False):
        """
        Compiles the model.

        Args:
            fine_tuning (bool, optional): If True, the model will be compiled for fine-tuning. Defaults to False.
        """
        raise NotImplementedError


    # Optional function: Whether you implement this function or not depends on your application.
    def train_model(self, fine_tuning=False, tensorboard_log_dir=None, best_weights_dir=None, use_wandb=False):
        """
        Trains the model.

        Args:
            fine_tuning (bool, optional): If True, the model will be trained for fine-tuning. Defaults to False.
            tensorboard_log_dir (str, optional): The directory where the logs should be saved. If None, the logs won't be saved. Defaults to None.
            best_weights_dir (str, optional): The directory where the best weights should be saved. If None, the best weights won't be saved. Defaults to None.
            use_wandb (bool, optional): If True, the training progress will be logged to W&B. Defaults to False.

        Returns:
            Optional[tf.keras.callbacks.History]: The training history or None.
        """
        raise NotImplementedError


    # Optional function: Whether you implement this function or not depends on your application.
    def evaluate_model(self):
        """
        Evaluates the model.

        Returns:
            dict: The evaluation metrics.
        """
        raise NotImplementedError


    def get_model_info(self):
        """
        Returns the model info that can be anything important, including its configuration.

        Returns:
            dict: The model info.
        """
        raise NotImplementedError


    def get_params_count(self):
        """
        Returns the number of parameters in the model.

        Returns:
            tuple[int, int, int]: The total number of parameters, the number of trainable parameters, and the number of non-trainable parameters.
        """
        total_params = 0
        trainable_params = 0
        non_trainable_params = 0

        for variable in self.model.variables:
            total_params += np.prod(variable.shape)

        for variable in self.model.trainable_variables:
            trainable_params += np.prod(variable.shape)

        non_trainable_params = total_params - trainable_params

        return int(total_params), int(trainable_params), int(non_trainable_params)


    def get_FLOPs(self):
        """
        Returns the number of FLOPs of the model.

        Returns:
            int: The number of FLOPs.
        """
        input_signature = [
            tf.TensorSpec(
                shape=(1, *params.shape[1:]),
                dtype=params.dtype,
                name=params.name
            ) for params in self.model.inputs
        ]
        forward_graph = tf.function(self.model, input_signature).get_concrete_function().graph
        options = option_builder.ProfileOptionBuilder.float_operation()
        options['output'] = 'none'
        graph_info = model_analyzer.profile(forward_graph, options=options)

        FLOPs = graph_info.total_float_ops

        return FLOPs


    def measure_execution_time(self):
        """
        Measures the execution time of the model.
        The process starts with a warm-up phase of 100 iterations; the execution time is then measured for roughly 10 seconds.

        Returns:
            float: The execution time in ms.
        """
        rng = np.random.RandomState(42)
        sample_idx = rng.randint(0, self.dataset.train_x.shape[0])
        x = np.array([self.dataset.train_x[sample_idx]])

        # warm up
        tic = time.time()
        for i in range(100):
            self.model(x, training=False)
        toc = time.time()
        itr = int(10 * 100 / (toc - tic))

        # run the test
        tic = time.time()
        for i in range(itr):
            self.model(x, training=False)
        toc = time.time()
        execution_time = (toc-tic)/itr*1000     # in ms

        return execution_time


    def save_representative_data(self, n_samples, save_dir):
        """
        Saves the representative data with shape (samples, *input_shape).

        Args:
            n_samples (int): The number of samples to be saved.
            save_dir (str): The directory where the data should be saved.
        """
        data_x = self.dataset.train_x[:n_samples]
        os.makedirs(save_dir, exist_ok=True)
        np.save(os.path.join(save_dir, 'representative_data.npy'), data_x)


    # Optional function
    @staticmethod
    def load_representative_data(load_dir):
        """
        Loads the representative data.

        Args:
            load_dir (str): The directory where the representative data is stored.

        Returns:
            numpy.ndarray: The representative data.
        """
        representative_data = np.load(os.path.join(load_dir, 'representative_data.npy'))
        return representative_data


    def save_eqcheck_data(self, n_samples, save_dir):
        """
        Saves the eqcheck data as {"data_x", "data_y_pred"}.

        The data_x has shape (samples, *input_shape) and data_y_pred has shape (samples, *output_shape).

        Args:
            n_samples (int): The number of samples to be saved
            save_dir (str): The directory where the data should be saved
        """
        # sanity check (useful for when deploying the model using TFLM)
        for data_dim_size, model_dim_size in zip(self.dataset.train_x.shape[1:], self.model.inputs[0].shape[1:]):
            if data_dim_size != model_dim_size and model_dim_size is not None:
                raise ValueError("The shape of the train_x doesn't match the input shape of the model.")

        data_x = self.dataset.train_x[:n_samples]
        data_y_pred = self.model.predict(data_x, verbose=0)

        os.makedirs(save_dir, exist_ok=True)
        np.savez(os.path.join(save_dir, 'eqcheck_data.npz'), data_x=data_x, data_y_pred=data_y_pred)


    # Optional function
    @staticmethod
    def load_eqcheck_data(load_dir):
        """
        Loads the eqcheck data.

        Args:
            load_dir (str): The directory where the eqcheck data is stored.

        Returns:
            tuple: (data_x, data_y_pred)
        """
        eqcheck_data = np.load(os.path.join(load_dir, 'eqcheck_data.npz'))
        data_x = eqcheck_data['data_x']
        data_y_pred = eqcheck_data['data_y_pred']
        return data_x, data_y_pred


    def save_model(self, save_dir):
        """
        Saves the model in two formats: Keras and SavedModel.

        Args:
            save_dir (str): The directory where the model should be saved.
        """
        # save the model as a Keras format
        os.makedirs(os.path.join(save_dir, "keras_format"), exist_ok=True)
        self.model.save(os.path.join(save_dir, "keras_format/model.keras"))

        # save the model as a SavedModel format
        os.makedirs(os.path.join(save_dir, "saved_model_format"), exist_ok=True)
        self.model.save(os.path.join(save_dir, "saved_model_format"))


    def load_model(self, load_dir):
        """
        Loads the model in the SavedModel format.

        Args:
            load_dir (str): The parent directory where the SavedModel format is stored.
        """
        self.model = tf.keras.models.load_model(os.path.join(load_dir, "saved_model_format"))


    @staticmethod
    def log_model_to_wandb(model_dir, model_save_name):
        """
        Logs the model to W&B.

        Args:
            model_dir (str): The directory where the model is stored.
            model_save_name (str): The name that will be assigned to the model artifact.
        """
        model_path = os.path.join(model_dir, "keras_format/model.keras")
        model_artifact = wandb.Artifact(model_save_name, type="model")
        model_artifact.add_file(model_path)
        wandb.log_artifact(model_artifact)


    def save_weights(self, save_dir):
        """
        Saves the model weights.

        Args:
            save_dir (str): The directory where the model weights should be saved.
        """
        os.makedirs(save_dir, exist_ok=True)
        self.model.save_weights(os.path.join(save_dir, "weights"))


    def load_weights(self, load_dir):
        """
        Loads the model weights.

        Args:
            load_dir (str): The directory where the model weights are stored.
        """
        self.model.load_weights(os.path.join(load_dir, "weights"))


    @staticmethod
    def save_model_info(model_info, save_dir):
        """
        Saves the model info.

        Args:
            model_info (dict): The model info.
            save_dir (str): The directory where the model info should be saved.
        """
        os.makedirs(save_dir, exist_ok=True)
        yaml.Dumper.ignore_aliases = lambda *args : True
        with open(os.path.join(save_dir, "model_info.yaml"), 'w') as f:
            yaml.dump(model_info, f, indent=4, sort_keys=False)


    def save_TFLM_info(self, save_path):
        """
        Saves the information required by the TFLM converter as a YAML file.
        This is to help the TFLM converter in a later stage.

        Args:
            save_path (str): The YAML file path where the TFLM info should be saved.
        """
        TFLM_info = {}

        arena_size_base = self._estimate_arena_size()
        TFLM_info["arena_size"] = {
            "32bit": arena_size_base,
            "16bit": arena_size_base//2,
            "8bit": arena_size_base//4
        }

        TFLM_info["input_dims"] = [int(dim) for dim in self.model.inputs[0].shape[1:]]
        TFLM_info["output_dims"] = [int(dim) for dim in self.model.outputs[0].shape[1:]]

        TFLM_info["op_resolver_funcs"] = None
        try:
            TFLM_info["op_resolver_funcs"] = self._get_op_resolver_funcs()
        except NotImplementedError:
            print("The model doesn't have an implementation for the _get_op_resolver_funcs function. The placeholders should be filled manually.")

        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        with open(save_path, 'w') as f:
            yaml.dump(TFLM_info, f, indent=4, sort_keys=False)


    def _get_training_callbacks(self, tensorboard_log_dir=None, best_weights_dir=None, use_wandb=False):
        """
        Returns the training callbacks.

        Args:
            tensorboard_log_dir (str, optional): The directory where the logs should be saved. If None, the logs won't be saved. Defaults to None.
            best_weights_dir (str, optional): The directory where the best weights should be saved. If None, the best weights won't be saved. Defaults to None.
            use_wandb (bool, optional): If True, the training progress will be logged to wandb. Defaults to False.
        """
        callbacks = []

        if tensorboard_log_dir is not None:
            callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=tensorboard_log_dir, histogram_freq=1))

        if best_weights_dir is not None:
            best_weights_path = os.path.join(best_weights_dir, "weights")
            callbacks.append(tf.keras.callbacks.ModelCheckpoint(best_weights_path, save_best_only=True, save_weights_only=True, verbose=0))

        if use_wandb:
            callbacks.append(WandbCallback(save_model=False, log_weights=True, compute_flops=True))

        return callbacks


    def _estimate_arena_size(self):
        """
        Estimates the size of the arena for the TFLM model.
        This is to help the TFLM converter in a later stage.
        Note: Depending on your model architecture, you may need to override this function.

        Returns:
            int: The size of the arena in bytes.
        """

        # Note: Assumes a Sequential model, and that TFLM performs in-place operations where possible.

        def _mul_dims(dims):
            output = 1
            for dim in dims:
                output *= dim
            return output
        arena_size = 0
        layer_1_size = _mul_dims(self.model.layers[0].input_shape[1:])

        for layer in self.model.layers:
            if isinstance(layer, tf.keras.layers.InputLayer):
                continue

            elif isinstance(layer, tf.keras.layers.Dense):
                layer_2_size = _mul_dims(layer.output_shape[1:])

            elif isinstance(layer, tf.keras.layers.Conv2D):
                layer_2_size = _mul_dims(layer.output_shape[1:])

            elif isinstance(layer, tf.keras.layers.DepthwiseConv2D):
                layer_2_size = _mul_dims(layer.output_shape[1:])

            elif isinstance(layer, tf.keras.layers.MaxPooling2D):
                layer_2_size = _mul_dims(layer.output_shape[1:])

            elif isinstance(layer, tf.keras.layers.AveragePooling2D):
                layer_2_size = _mul_dims(layer.output_shape[1:])

            elif isinstance(layer, tf.keras.layers.GlobalAveragePooling2D):
                layer_2_size = _mul_dims(layer.output_shape[1:])

            elif isinstance(layer, tf.keras.layers.ZeroPadding2D):
                layer_2_size = _mul_dims(layer.output_shape[1:])

            elif isinstance(layer, tf.keras.layers.Flatten):
                continue

            elif isinstance(layer, tf.keras.layers.Add):
                continue

            elif isinstance(layer, tf.keras.layers.BatchNormalization):
                continue

            elif isinstance(layer, tf.keras.layers.Activation):
                continue

            elif isinstance(layer, tf.keras.layers.Dropout):
                continue

            elif isinstance(layer, tf.keras.layers.ReLU):
                continue

            elif isinstance(layer, tf.keras.layers.Softmax):
                continue

            else:
                raise ValueError("Unknown layer type: {}".format(layer))

            if layer_1_size + layer_2_size > arena_size:
                arena_size = layer_1_size + layer_2_size
            layer_1_size = layer_2_size

        arena_size = arena_size * 4     # 4 bytes for each float32
        return arena_size


    def _get_op_resolver_funcs(self):
        """
        Returns the operators needed to run the TFLM model.
        This is to help the TFLM converter in a later stage.
        Possible strings in the output can be found here:
            https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/micro_mutable_op_resolver.h

        Returns:
            list[str]: The op resolvers to be called by the C++ code.

        Example:
            >>> get_op_resolver_funcs()
            ["AddFullyConnected()", "AddRelu()", "AddSoftmax()"]
        """
        output = []
        for layer in self.model.layers:
            # TODO: check if registration is needed for non-float32 operations (e.g., AddConv2D(tflite::Register_CONV_2D_INT8()))
            if isinstance(layer, tf.keras.layers.InputLayer):
                pass

            elif isinstance(layer, tf.keras.layers.Dense):
                if "AddFullyConnected()" not in output:
                    output.append("AddFullyConnected()")

            elif isinstance(layer, tf.keras.layers.Conv2D):
                if "AddConv2D()" not in output:
                    output.append("AddConv2D()")

            elif isinstance(layer, tf.keras.layers.DepthwiseConv2D):
                if "AddDepthwiseConv2D()" not in output:
                    output.append("AddDepthwiseConv2D()")

            elif isinstance(layer, tf.keras.layers.MaxPooling2D):
                if "AddMaxPool2D()" not in output:
                    output.append("AddMaxPool2D()")

            elif isinstance(layer, tf.keras.layers.AveragePooling2D):
                if "AddAveragePool2D()" not in output:
                    output.append("AddAveragePool2D()")

            elif isinstance(layer, tf.keras.layers.GlobalAveragePooling2D):
                if "AddMean()" not in output:
                    output.append("AddMean()")

            elif isinstance(layer, tf.keras.layers.ZeroPadding2D):
                if "AddPad()" not in output:
                    output.append("AddPad()")

            elif isinstance(layer, tf.keras.layers.Flatten):
                if "AddReshape()" not in output:
                    output.append("AddReshape()")

            elif isinstance(layer, tf.keras.layers.Add):
                if "AddAdd()" not in output:
                    output.append("AddAdd()")

            elif isinstance(layer, tf.keras.layers.BatchNormalization):
                pass

            elif isinstance(layer, tf.keras.layers.Embedding):
                if "AddCast()" not in output:
                    output.append("AddCast()")
                if "AddGather()" not in output:
                    output.append("AddGather()")

            elif isinstance(layer, tf.keras.layers.SimpleRNN):
                if "AddReshape()" not in output:
                    output.append("AddReshape()")
                if "AddFullyConnected()" not in output:
                    output.append("AddFullyConnected()")
                if "AddAdd()" not in output:
                    output.append("AddAdd()")
                if "AddTanh()" not in output:
                    output.append("AddTanh()")
                if "AddPack()" not in output:
                    output.append("AddPack()")
                if "AddUnpack()" not in output:
                    output.append("AddUnpack()")
                if "AddQuantize()" not in output:       # needed for int8_only quantization
                    output.append("AddQuantize()")
                if "AddDequantize()" not in output:     # needed for int8_only quantization
                    output.append("AddDequantize()")

            elif isinstance(layer, tf.keras.layers.LSTM):
                if "AddReshape()" not in output:
                    output.append("AddReshape()")
                if "AddFullyConnected()" not in output:
                    output.append("AddFullyConnected()")
                if "AddAdd()" not in output:
                    output.append("AddAdd()")
                if "AddTanh()" not in output:
                    output.append("AddTanh()")
                if "AddPack()" not in output:
                    output.append("AddPack()")
                if "AddUnpack()" not in output:
                    output.append("AddUnpack()")
                if "AddSplit()" not in output:
                    output.append("AddSplit()")
                if "AddLogistic()" not in output:
                    output.append("AddLogistic()")
                if "AddMul()" not in output:
                    output.append("AddMul()")
                if "AddQuantize()" not in output:       # needed for int8_only quantization
                    output.append("AddQuantize()")
                if "AddDequantize()" not in output:     # needed for int8_only quantization
                    output.append("AddDequantize()")

            elif isinstance(layer, tf.keras.layers.GRU):
                if "AddReshape()" not in output:
                    output.append("AddReshape()")
                if "AddFullyConnected()" not in output:
                    output.append("AddFullyConnected()")
                if "AddAdd()" not in output:
                    output.append("AddAdd()")
                if "AddTanh()" not in output:
                    output.append("AddTanh()")
                if "AddPack()" not in output:
                    output.append("AddPack()")
                if "AddUnpack()" not in output:
                    output.append("AddUnpack()")
                if "AddSplit()" not in output:
                    output.append("AddSplit()")
                if "AddLogistic()" not in output:
                    output.append("AddLogistic()")
                if "AddMul()" not in output:
                    output.append("AddMul()")
                if "AddSub()" not in output:
                    output.append("AddSub()")
                if "AddSplitV()" not in output:
                    output.append("AddSplitV()")

            elif isinstance(layer, tf.keras.layers.Activation):
                pass

            elif isinstance(layer, tf.keras.layers.Dropout):
                pass

            elif isinstance(layer, tf.keras.layers.ReLU):
                if "AddRelu()" not in output:
                    output.append("AddRelu()")

            elif isinstance(layer, tf.keras.layers.Softmax):
                if "AddSoftmax()" not in output:
                    output.append("AddSoftmax()")

            else:
                raise ValueError("Unknown layer type: {}".format(layer))

            try:
                if layer.activation is tf.keras.activations.sigmoid:
                    if "AddLogistic()" not in output:
                        output.append("AddLogistic()")

                elif layer.activation is tf.keras.activations.tanh:
                    if "AddTanh()" not in output:
                        output.append("AddTanh()")

                elif layer.activation is tf.keras.activations.relu:
                    if "AddRelu()" not in output:
                        output.append("AddRelu()")

                elif layer.activation is tf.keras.activations.softmax:
                    if "AddSoftmax()" not in output:
                        output.append("AddSoftmax()")

            except AttributeError:
                # the layer has no `activation` attribute
                pass

        return output
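A hypothetical consumer of this list would render the strings into C++ registration calls on a `tflite::MicroMutableOpResolver` (whose template argument is the op count). The function and variable names below are illustrative, not part of this module:

```python
def render_op_resolver(op_funcs, resolver_name="op_resolver"):
    """Render C++ lines that declare a tflite::MicroMutableOpResolver
    and register each op returned by _get_op_resolver_funcs."""
    lines = [
        "tflite::MicroMutableOpResolver<{}> {};".format(len(op_funcs), resolver_name)
    ]
    for func in op_funcs:
        lines.append("{}.{};".format(resolver_name, func))
    return "\n".join(lines)


print(render_op_resolver(["AddFullyConnected()", "AddRelu()", "AddSoftmax()"]))
```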

__init__

__init__(cfg=None)

Initializes the class. The following attributes should be set in the `__init__` function:

- self.model (tf.keras.Model): The model.
- self.dataset (DatasetSupervisorTemplate): The dataset.

Parameters:

Name Type Description Default
cfg dict

The configurations of the model. Defaults to None.

None
Source code in edgemark/models/platforms/TensorFlow/model_template.py
def __init__(self, cfg=None):
    """
    Initializes the class.
    The following attributes should be set in the __init__ function:
        self.model (tf.keras.Model): The model.
        self.dataset (DatasetSupervisorTemplate): The dataset.

    Args:
        cfg (dict): The configurations of the model. Defaults to None.
    """
    self.model = None
    self.dataset = None

compile_model

compile_model(fine_tuning=False)

Compiles the model.

Parameters:

Name Type Description Default
fine_tuning bool

If True, the model will be compiled for fine-tuning. Defaults to False.

False
Source code in edgemark/models/platforms/TensorFlow/model_template.py
def compile_model(self, fine_tuning=False):
    """
    Compiles the model.

    Args:
        fine_tuning (bool, optional): If True, the model will be compiled for fine-tuning. Defaults to False.
    """
    raise NotImplementedError

evaluate_model

evaluate_model()

Evaluates the model.

Returns:

Type Description
dict

The evaluation metrics.

Source code in edgemark/models/platforms/TensorFlow/model_template.py
def evaluate_model(self):
    """
    Evaluates the model.

    Returns:
        dict: The evaluation metrics.
    """
    raise NotImplementedError

get_FLOPs

get_FLOPs()

Returns the number of FLOPs of the model.

Returns:

Type Description
int

The number of FLOPs.

Source code in edgemark/models/platforms/TensorFlow/model_template.py
def get_FLOPs(self):
    """
    Returns the number of FLOPs of the model.

    Returns:
        int: The number of FLOPs.
    """
    input_signature = [
        tf.TensorSpec(
            shape=(1, *params.shape[1:]),
            dtype=params.dtype,
            name=params.name
        ) for params in self.model.inputs
    ]
    forward_graph = tf.function(self.model, input_signature).get_concrete_function().graph
    options = option_builder.ProfileOptionBuilder.float_operation()
    options['output'] = 'none'
    graph_info = model_analyzer.profile(forward_graph, options=options)

    FLOPs = graph_info.total_float_ops

    return FLOPs
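The profiler counts multiplies and adds separately, so a Dense layer contributes roughly 2·in·out float ops plus the bias adds. A back-of-the-envelope check of that convention (this helper is illustrative, not part of the module):

```python
def dense_flops(n_in, n_out, bias=True):
    """Approximate float ops of a Dense layer: one multiply and one
    add per weight, plus one add per bias term."""
    flops = 2 * n_in * n_out
    if bias:
        flops += n_out
    return flops


print(dense_flops(784, 10))  # 15690
```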

get_model_info

get_model_info()

Returns the model info that can be anything important, including its configuration.

Returns:

Type Description
dict

The model info.

Source code in edgemark/models/platforms/TensorFlow/model_template.py
def get_model_info(self):
    """
    Returns the model info that can be anything important, including its configuration.

    Returns:
        dict: The model info.
    """
    raise NotImplementedError

get_params_count

get_params_count()

Returns the number of parameters in the model.

Returns:

Type Description
list[int]

The total number of parameters, the number of trainable parameters, and the number of non-trainable parameters.

Source code in edgemark/models/platforms/TensorFlow/model_template.py
def get_params_count(self):
    """
    Returns the number of parameters in the model.

    Returns:
        list[int]: The total number of parameters, the number of trainable parameters, and the number of non-trainable parameters.
    """
    total_params = 0
    trainable_params = 0
    non_trainable_params = 0

    for layer in self.model.variables:
        total_params += np.prod(layer.shape)

    for layer in self.model.trainable_variables:
        trainable_params += np.prod(layer.shape)

    non_trainable_params = total_params - trainable_params

    return int(total_params), int(trainable_params), int(non_trainable_params)
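The same counting can be verified by hand on raw arrays. A sketch with hypothetical weight shapes for a Dense(784 → 10) layer:

```python
import numpy as np


def count_params(variables):
    """Sum the element counts of a list of arrays (model variables)."""
    return int(sum(np.prod(v.shape) for v in variables))


# kernel (784, 10) + bias (10,) = 7850 parameters
weights = [np.zeros((784, 10)), np.zeros((10,))]
print(count_params(weights))  # 7850
```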

load_eqcheck_data staticmethod

load_eqcheck_data(load_dir)

Loads the eqcheck data.

Parameters:

Name Type Description Default
load_dir str

The directory where the eqcheck data is stored.

required

Returns:

Type Description
tuple

(data_x, data_y_pred)

Source code in edgemark/models/platforms/TensorFlow/model_template.py
@staticmethod
def load_eqcheck_data(load_dir):
    """
    Loads the eqcheck data.

    Args:
        load_dir (str): The directory where the eqcheck data is stored.

    Returns:
        tuple: (data_x, data_y_pred)
    """
    eqcheck_data = np.load(os.path.join(load_dir, 'eqcheck_data.npz'))
    data_x = eqcheck_data['data_x']
    data_y_pred = eqcheck_data['data_y_pred']
    return data_x, data_y_pred
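The `.npz` layout used here can be exercised without the class. A round-trip sketch (shapes are made up for illustration):

```python
import os
import tempfile

import numpy as np

with tempfile.TemporaryDirectory() as save_dir:
    data_x = np.random.rand(8, 28, 28, 1).astype(np.float32)
    data_y_pred = np.random.rand(8, 10).astype(np.float32)

    # save in the same layout as save_eqcheck_data
    np.savez(os.path.join(save_dir, "eqcheck_data.npz"),
             data_x=data_x, data_y_pred=data_y_pred)

    # load in the same way as load_eqcheck_data
    loaded = np.load(os.path.join(save_dir, "eqcheck_data.npz"))
    assert np.array_equal(loaded["data_x"], data_x)
    assert np.array_equal(loaded["data_y_pred"], data_y_pred)
    loaded.close()
```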

load_model

load_model(load_dir)

Loads the model in the SavedModel format.

Parameters:

Name Type Description Default
load_dir str

The parent directory in which the SavedModel format is stored.

required
Source code in edgemark/models/platforms/TensorFlow/model_template.py
def load_model(self, load_dir):
    """
    Loads the model in the SavedModel format.

    Args:
        load_dir (str): The parent directory where the SavedModel format is stored in.
    """
    self.model = tf.keras.models.load_model(os.path.join(load_dir, "saved_model_format"))

load_representative_data staticmethod

load_representative_data(load_dir)

Loads the representative data.

Parameters:

Name Type Description Default
load_dir str

The directory where the representative data is stored.

required

Returns:

Type Description
ndarray

The representative data.

Source code in edgemark/models/platforms/TensorFlow/model_template.py
@staticmethod
def load_representative_data(load_dir):
    """
    Loads the representative data.

    Args:
        load_dir (str): The directory where the representative data is stored.

    Returns:
        numpy.ndarray: The representative data.
    """
    representative_data = np.load(os.path.join(load_dir, 'representative_data.npy'))
    return representative_data

load_weights

load_weights(load_dir)

Loads the model weights.

Parameters:

Name Type Description Default
load_dir str

The directory where the model weights are stored.

required
Source code in edgemark/models/platforms/TensorFlow/model_template.py
def load_weights(self, load_dir):
    """
    Loads the model weights.

    Args:
        load_dir (str): The directory where the model weights are stored.
    """
    self.model.load_weights(os.path.join(load_dir, "weights"))

log_model_to_wandb staticmethod

log_model_to_wandb(model_dir, model_save_name)

Logs the model to W&B.

Parameters:

Name Type Description Default
model_dir str

The directory where the model is stored.

required
model_save_name str

The name that will be assigned to the model artifact.

required
Source code in edgemark/models/platforms/TensorFlow/model_template.py
@staticmethod
def log_model_to_wandb(model_dir, model_save_name):
    """
    Logs the model to W&B.

    Args:
        model_dir (str): The directory where the model is stored.
        model_save_name (str): The name that will be assigned to the model artifact.
    """
    model_path = os.path.join(model_dir, "keras_format/model.keras")
    model_artifact = wandb.Artifact(model_save_name, type="model")
    model_artifact.add_file(model_path)
    wandb.log_artifact(model_artifact)

measure_execution_time

measure_execution_time()

Measures the execution time of the model. The process starts with a warm-up phase of 100 iterations; the execution time is then measured over ~10 seconds.

Returns:

Type Description
float

The execution time in ms.

Source code in edgemark/models/platforms/TensorFlow/model_template.py
def measure_execution_time(self):
    """
    Measures the execution time of the model.
    The process starts with a warm-up phase of 100 iterations; the execution time is then measured over ~10 seconds.

    Returns:
        float: The execution time in ms.
    """
    rng = np.random.RandomState(42)
    sample_idx = rng.randint(0, self.dataset.train_x.shape[0])
    x = np.array([self.dataset.train_x[sample_idx]])

    # warm up
    tic = time.time()
    for i in range(100):
        self.model(x, training=False)
    toc = time.time()
    itr = int(10 * 100 / (toc - tic))

    # run the test
    tic = time.time()
    for i in range(itr):
        self.model(x, training=False)
    toc = time.time()
    execution_time = (toc-tic)/itr*1000     # in ms

    return execution_time
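The warm-up/measure pattern above generalizes to any callable. A standalone sketch mirroring the method (the 100-iteration warm-up and time budget are the same idea; `fn` and the helper name are illustrative):

```python
import time


def time_callable(fn, warmup_iters=100, budget_s=10.0):
    """Return the mean latency of fn() in ms, after a warm-up phase
    that is also used to pick an iteration count filling ~budget_s."""
    tic = time.time()
    for _ in range(warmup_iters):
        fn()
    elapsed = max(time.time() - tic, 1e-9)      # guard against zero division
    iters = max(1, int(budget_s * warmup_iters / elapsed))

    tic = time.time()
    for _ in range(iters):
        fn()
    return (time.time() - tic) / iters * 1000.0  # in ms


# small budget so the example finishes quickly
latency_ms = time_callable(lambda: sum(range(1000)), budget_s=0.1)
```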

save_TFLM_info

save_TFLM_info(save_path)

Saves the information required by the TFLM converter as a YAML file. This is to help the TFLM converter in a later stage.

Parameters:

Name Type Description Default
save_path str

The YAML file path where the TFLM info should be saved.

required
Source code in edgemark/models/platforms/TensorFlow/model_template.py
def save_TFLM_info(self, save_path):
    """
    Saves the information required by the TFLM converter as a YAML file.
    This is to help the TFLM converter in a later stage.

    Args:
        save_path (str): The YAML file path where the TFLM info should be saved.
    """
    TFLM_info = {}

    arena_size_base = self._estimate_arena_size()
    TFLM_info["arena_size"] = {
        "32bit": arena_size_base,
        "16bit": arena_size_base//2,
        "8bit": arena_size_base//4
    }

    TFLM_info["input_dims"] = [int(dim) for dim in self.model.inputs[0].shape[1:]]
    TFLM_info["output_dims"] = [int(dim) for dim in self.model.outputs[0].shape[1:]]

    TFLM_info["op_resolver_funcs"] = None
    try:
        TFLM_info["op_resolver_funcs"] = self._get_op_resolver_funcs()
    except NotImplementedError:
        print("The model doesn't have an implementation for the _get_op_resolver_funcs function. The placeholders should be filled manually.")

    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, 'w') as f:
        yaml.dump(TFLM_info, f, indent=4, sort_keys=False)
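The per-precision arena entries are derived by scaling the float32 estimate down by the element width. As a small sketch of that derivation (the helper name is illustrative):

```python
def arena_sizes(arena_size_32bit):
    """Scale a float32 (4-byte) arena estimate for 16-bit and 8-bit tensors."""
    return {
        "32bit": arena_size_32bit,
        "16bit": arena_size_32bit // 2,
        "8bit": arena_size_32bit // 4,
    }


print(arena_sizes(40000))  # {'32bit': 40000, '16bit': 20000, '8bit': 10000}
```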

save_eqcheck_data

save_eqcheck_data(n_samples, save_dir)

Saves the eqcheck data as {"data_x", "data_y_pred"}.

The data_x has shape (samples, *input_shape) and data_y_pred has shape (samples, *output_shape).

Parameters:

Name Type Description Default
n_samples int

The number of samples to be saved

required
save_dir str

The directory where the data should be saved

required
Source code in edgemark/models/platforms/TensorFlow/model_template.py
def save_eqcheck_data(self, n_samples, save_dir):
    """
    Saves the eqcheck data as {"data_x", "data_y_pred"}.

    The data_x has shape (samples, *input_shape) and data_y_pred has shape (samples, *output_shape).

    Args:
        n_samples (int): The number of samples to be saved
        save_dir (str): The directory where the data should be saved
    """
    # sanity check (useful for when deploying the model using TFLM)
    for data_dim_size, model_dim_size in zip(self.dataset.train_x.shape[1:], self.model.inputs[0].shape[1:]):
        if data_dim_size != model_dim_size and model_dim_size is not None:
            raise ValueError("The shape of the train_x doesn't match the input shape of the model.")

    data_x = self.dataset.train_x[:n_samples]
    data_y_pred = self.model.predict(data_x, verbose=0)

    np.savez(os.path.join(save_dir, 'eqcheck_data.npz'), data_x=data_x, data_y_pred=data_y_pred)
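The zip-based sanity check above treats `None` model dimensions as dynamic (always matching). A standalone sketch of that check, with a hypothetical helper name:

```python
def check_input_shape(data_shape, model_input_shape):
    """Raise if a fixed model dimension disagrees with the data.
    None entries in the model shape are dynamic and always match;
    the leading (batch) dimension is skipped."""
    for data_dim, model_dim in zip(data_shape[1:], model_input_shape[1:]):
        if model_dim is not None and data_dim != model_dim:
            raise ValueError(
                "The shape of the data doesn't match the input shape of the model."
            )


check_input_shape((100, 28, 28, 1), (None, 28, 28, 1))      # ok: exact match
check_input_shape((100, 32, 32, 3), (None, None, None, 3))  # ok: dynamic dims
```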

save_model

save_model(save_dir)

Saves the model in two formats: Keras and SavedModel.

Parameters:

Name Type Description Default
save_dir str

The directory in which the model should be saved.

required
Source code in edgemark/models/platforms/TensorFlow/model_template.py
def save_model(self, save_dir):
    """
    Saves the model in two formats: Keras and SavedModel.

    Args:
        save_dir (str): The directory where the model should be saved in.
    """
    # save the model as a Keras format
    os.makedirs(os.path.join(save_dir, "keras_format"), exist_ok=True)
    self.model.save(os.path.join(save_dir, "keras_format/model.keras"))

    # save the model as a SavedModel format
    os.makedirs(os.path.join(save_dir, "saved_model_format"), exist_ok=True)
    self.model.save(os.path.join(save_dir, "saved_model_format"))

save_model_info staticmethod

save_model_info(model_info, save_dir)

Saves the model info.

Parameters:

Name Type Description Default
model_info dict

The model info.

required
save_dir str

The directory where the model info should be saved.

required
Source code in edgemark/models/platforms/TensorFlow/model_template.py
@staticmethod
def save_model_info(model_info, save_dir):
    """
    Saves the model info.

    Args:
        model_info (dict): The model info.
        save_dir (str): The directory where the model info should be saved.
    """
    os.makedirs(save_dir, exist_ok=True)
    yaml.Dumper.ignore_aliases = lambda *args: True     # disable YAML anchors/aliases for repeated objects
    with open(os.path.join(save_dir, "model_info.yaml"), 'w') as f:
        yaml.dump(model_info, f, indent=4, sort_keys=False)

save_representative_data

save_representative_data(n_samples, save_dir)

Saves the representative data with shape (samples, *input_shape).

Parameters:

Name Type Description Default
n_samples int

The number of samples to be saved.

required
save_dir str

The directory where the data should be saved.

required
Source code in edgemark/models/platforms/TensorFlow/model_template.py
def save_representative_data(self, n_samples, save_dir):
    """
    Saves the representative data with shape (samples, *input_shape).

    Args:
        n_samples (int): The number of samples to be saved.
        save_dir (str): The directory where the data should be saved.
    """
    data_x = self.dataset.train_x[:n_samples]
    os.makedirs(save_dir, exist_ok=True)
    np.save(os.path.join(save_dir, 'representative_data.npy'), data_x)

save_weights

save_weights(save_dir)

Saves the model weights.

Parameters:

Name Type Description Default
save_dir str

The directory where the model weights should be saved.

required
Source code in edgemark/models/platforms/TensorFlow/model_template.py
def save_weights(self, save_dir):
    """
    Saves the model weights.

    Args:
        save_dir (str): The directory where the model weights should be saved.
    """
    os.makedirs(save_dir, exist_ok=True)
    self.model.save_weights(os.path.join(save_dir, "weights"))

set_configs

set_configs(cfg)

Sets the configs from the given dictionary.

Note: The changed configs won't affect the data, model, or any other loaded attributes. To update those, call the corresponding functions.

Parameters:

Name Type Description Default
cfg dict

The configurations.

required
Source code in edgemark/models/platforms/TensorFlow/model_template.py
def set_configs(self, cfg):
    """
    Sets the configs from the given dictionary.

    Note: The changed configs won't affect the data, model, or any other loaded attributes.
    In case you want to change them, you should call the corresponding functions.

    Args:
        cfg (dict): The configurations.
    """
    raise NotImplementedError

train_model

train_model(fine_tuning=False, tensorboard_log_dir=None, best_weights_dir=None, use_wandb=False)

Trains the model.

Parameters:

Name Type Description Default
fine_tuning bool

If True, the model will be trained for fine-tuning. Defaults to False.

False
tensorboard_log_dir str

The directory where the logs should be saved. If None, the logs won't be saved. Defaults to None.

None
best_weights_dir str

The directory where the best weights should be saved. If None, the best weights won't be saved. Defaults to None.

None
use_wandb bool

If True, the training progress will be logged to W&B. Defaults to False.

False

Returns:

Type Description
Optional[History]

The training history or None.

Source code in edgemark/models/platforms/TensorFlow/model_template.py
def train_model(self, fine_tuning=False, tensorboard_log_dir=None, best_weights_dir=None, use_wandb=False):
    """
    Trains the model.

    Args:
        fine_tuning (bool, optional): If True, the model will be trained for fine-tuning. Defaults to False.
        tensorboard_log_dir (str, optional): The directory where the logs should be saved. If None, the logs won't be saved. Defaults to None.
        best_weights_dir (str, optional): The directory where the best weights should be saved. If None, the best weights won't be saved. Defaults to None.
        use_wandb (bool, optional): If True, the training progress will be logged to W&B. Defaults to False.

    Returns:
        Optional[tf.keras.callbacks.History]: The training history or None.
    """
    raise NotImplementedError