Models package

Submodules

Models.AFNO module

class Models.AFNO.AFNONet(img_size=(720, 1440), patch_size=(16, 16), in_chans=2, out_chans=2, embed_dim=768, depth=12, mlp_ratio=4.0, drop_rate=0.0, drop_path_rate=0.0, num_blocks=16, sparsity_threshold=0.01, hard_thresholding_fraction=1.0, **model_kwargs)

Bases: Module

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward pass is defined in this method, call the Module instance itself (for example, model(x)) rather than forward(x) directly: the instance call runs any registered hooks, whereas a direct call to forward silently skips them.
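The difference between calling the instance and calling forward directly can be illustrated with a stripped-down stand-in for the call mechanism. The TinyModule class below is hypothetical and far simpler than torch.nn.Module; it only mimics the idea that the instance call runs registered forward hooks while a direct forward call does not.

```python
class TinyModule:
    """Hypothetical, simplified stand-in for torch.nn.Module's call logic."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # A hook receives (module, inputs, output), loosely mirroring PyTorch.
        self._forward_hooks.append(hook)

    def forward(self, x):
        # The actual computation; here it simply doubles the input.
        return 2 * x

    def __call__(self, x):
        # Calling the instance runs forward and then every registered hook.
        out = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, (x,), out)
        return out


calls = []
m = TinyModule()
m.register_forward_hook(lambda mod, inp, out: calls.append(out))

direct = m.forward(3)  # the hook is skipped
via_call = m(3)        # the hook runs once
```

Both calls return the same value, but only the instance call leaves a record in calls.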

forward_features(x)
no_weight_decay()
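
As a rough guide to how the default constructor arguments of AFNONet relate to each other, the sketch below works out the patch grid and the per-block channel width. It assumes the conventions of the published AFNO architecture (non-overlapping patches, embed_dim split evenly across num_blocks); the signature above does not confirm these details, and the helper name afno_shapes is hypothetical.

```python
def afno_shapes(img_size=(720, 1440), patch_size=(16, 16),
                embed_dim=768, num_blocks=16):
    """Hypothetical helper: derive patch-grid and block sizes from the defaults."""
    h = img_size[0] // patch_size[0]  # patches along the height
    w = img_size[1] // patch_size[1]  # patches along the width
    if embed_dim % num_blocks != 0:
        raise ValueError("embed_dim must be divisible by num_blocks")
    return {
        "grid": (h, w),
        "num_patches": h * w,
        "block_size": embed_dim // num_blocks,
    }


shapes = afno_shapes()
print(shapes)
```

Under these assumptions the defaults give a 45 x 90 grid (4050 patches) and 48 channels per block.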

Models.FNN module

class Models.FNN.FNN(in_channels: int, out_channels: tensor, domain_shape: tensor, hidden_channels=256, n_layers=4, **model_kwargs)

Bases: Module

Simple feedforward network designed as a component of ACW.

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward pass is defined in this method, call the Module instance itself (for example, model(x)) rather than forward(x) directly: the instance call runs any registered hooks, whereas a direct call to forward silently skips them.

Models.GP module

class Models.GP.EGP(train_x, train_y, mean_model_class: Module, in_channels: int, domain_shape: tensor, **model_kwargs)

Bases: ExactGP

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward pass is defined in this method, call the Module instance itself (for example, model(x)) rather than forward(x) directly: the instance call runs any registered hooks, whereas a direct call to forward silently skips them.

class Models.GP.GPWrapper(mean_model_class: Module, gp_model_class: Module, in_channels: int, domain_shape: tensor, num_data: int, dropout_rate: float, dataset, **model_kwargs)

Bases: Module

A wrapper for a GP model that properly shapes input and output.

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward pass is defined in this method, call the Module instance itself (for example, model(x)) rather than forward(x) directly: the instance call runs any registered hooks, whereas a direct call to forward silently skips them.

initialize_normalization_stats(batch_size=32)

Compute and set normalization statistics for input and output in the model from GridStationDataset.

Parameters:

model (nn.Module): The model with registered input/output mean/std buffers.
dataset (GridStationDataset): Dataset instance returning (x, y).
batch_size (int): Batch size for statistics computation.
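
One way such statistics could be accumulated batch by batch is sketched below in plain Python. The function name running_mean_std and the list of floats standing in for dataset batches are assumptions for illustration; the actual method works with GridStationDataset and the model's registered buffers, and may compute the statistics differently.

```python
import math


def running_mean_std(values, batch_size=32):
    """Hypothetical sketch: mean and population std accumulated over batches."""
    count, total, total_sq = 0, 0.0, 0.0
    for start in range(0, len(values), batch_size):
        batch = values[start:start + batch_size]
        count += len(batch)
        total += sum(batch)
        total_sq += sum(v * v for v in batch)
    mean = total / count
    # Clamp at zero to guard against tiny negative values from round-off.
    var = max(total_sq / count - mean * mean, 0.0)
    return mean, math.sqrt(var)


xs = [float(i) for i in range(100)]
mean, std = running_mean_std(xs, batch_size=32)
```

Accumulating sums and sums of squares keeps memory use independent of the dataset size, at the cost of some numerical precision compared with a two-pass computation.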

class Models.GP.VariationalGP(num_latents: int, input_dim: int, output_dim: int, m_inducing_points: int, **model_kwargs)

Bases: ApproximateGP

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward pass is defined in this method, call the Module instance itself (for example, model(x)) rather than forward(x) directly: the instance call runs any registered hooks, whereas a direct call to forward silently skips them.

Models.GST module

class Models.GST.CGST(in_channels: int, domain_shape: tensor)

Bases: Module

Covariance for the Gaussian Spatio-Temporal (GST) model. Estimates the mean and covariance from past observations, assuming no temporal dependence. The model is non-parametric and does not change with training.

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward pass is defined in this method, call the Module instance itself (for example, model(x)) rather than forward(x) directly: the instance call runs any registered hooks, whereas a direct call to forward silently skips them.

class Models.GST.MGST(in_channels: int, out_channels: tensor, domain_shape: tensor, **model_kwargs)

Bases: Module

Mean for the Gaussian Spatio-Temporal (GST) model. Estimates the mean and covariance from past observations, assuming no temporal dependence. The model is non-parametric and does not change with training.

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward pass is defined in this method, call the Module instance itself (for example, model(x)) rather than forward(x) directly: the instance call runs any registered hooks, whereas a direct call to forward silently skips them.
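
The GST description amounts to treating past observations as independent samples and taking their empirical moments. A minimal plain-Python sketch of that idea follows; the function name empirical_mean_cov is hypothetical, and the unbiased (n - 1) normaliser is an assumption, since the source does not state which estimator CGST and MGST use.

```python
def empirical_mean_cov(samples):
    """samples: list of equal-length observation vectors, one per time step."""
    n, d = len(samples), len(samples[0])
    # Per-component average over all time steps.
    mean = [sum(s[j] for s in samples) / n for j in range(d)]
    # Sample covariance; time steps are treated as independent draws.
    cov = [
        [
            sum((s[i] - mean[i]) * (s[j] - mean[j]) for s in samples) / (n - 1)
            for j in range(d)
        ]
        for i in range(d)
    ]
    return mean, cov


obs = [[1.0, 2.0], [3.0, 6.0], [5.0, 10.0]]
gst_mean, gst_cov = empirical_mean_cov(obs)
```

Because nothing here is learned, the result depends only on the observations passed in, which matches the statement that the model does not change with training.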

Models.VGP_optimizer module

class Models.VGP_optimizer.VGPOptimizer(model, hyperparameter_optimizer, ngd_lr=0.1, hyper_lr=0.01)

Bases: object

Custom optimizer for the variational GP (VGP) that combines an adjustable hyperparameter optimizer with efficient natural gradient descent (NGD) updates.

load_state_dict(state_dict)
state_dict()
step()
zero_grad()
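
The method list suggests that the class exposes the usual optimizer interface while driving two underlying optimizers, one for the NGD updates and one for the hyperparameters. The following is a sketch of that composite pattern only; CompositeOptimizer and the stub CountingOptimizer are hypothetical, and the internals of the real class are not shown in the source.

```python
class CountingOptimizer:
    """Stub optimizer used only to make the sketch runnable."""

    def __init__(self, lr):
        self.lr, self.steps = lr, 0

    def step(self):
        self.steps += 1

    def zero_grad(self):
        pass

    def state_dict(self):
        return {"lr": self.lr, "steps": self.steps}

    def load_state_dict(self, state):
        self.lr, self.steps = state["lr"], state["steps"]


class CompositeOptimizer:
    """Hypothetical wrapper that forwards each call to both sub-optimizers."""

    def __init__(self, ngd_optimizer, hyper_optimizer):
        self.ngd, self.hyper = ngd_optimizer, hyper_optimizer

    def step(self):
        self.ngd.step()
        self.hyper.step()

    def zero_grad(self):
        self.ngd.zero_grad()
        self.hyper.zero_grad()

    def state_dict(self):
        # Keep the two states under separate keys so they can be restored.
        return {"ngd": self.ngd.state_dict(), "hyper": self.hyper.state_dict()}

    def load_state_dict(self, state):
        self.ngd.load_state_dict(state["ngd"])
        self.hyper.load_state_dict(state["hyper"])


opt = CompositeOptimizer(CountingOptimizer(lr=0.1), CountingOptimizer(lr=0.01))
opt.zero_grad()
opt.step()
saved = opt.state_dict()

# Restore the saved state into a fresh wrapper.
restored = CompositeOptimizer(CountingOptimizer(lr=0.0), CountingOptimizer(lr=0.0))
restored.load_state_dict(saved)
```

The learning rates 0.1 and 0.01 mirror the ngd_lr and hyper_lr defaults in the signature above.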

Models.neural_models module

class Models.neural_models.ACW(mean_model_class: Module, covariance_model_class: Module, in_channels: int, domain_shape: tensor, **model_kwargs)

Bases: Module

Adjustable Channel Wrapper: a wrapper for neural networks that allows the input and output channels to be specified explicitly.

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward pass is defined in this method, call the Module instance itself (for example, model(x)) rather than forward(x) directly: the instance call runs any registered hooks, whereas a direct call to forward silently skips them.

class Models.neural_models.BNN(probabilistic_model: ProbabilisticModel, in_channels: int, model_class: Module, **model_kwargs)

Bases: ACW

A Bayesian neural network: a neural network (ACW) whose output is specified by the given probabilistic model.

sample(x, n)

class Models.neural_models.PNN(probabilistic_model: ProbabilisticModel, mean_model_class: Module, covariance_model_class: Module, in_channels: int, domain_shape: tensor, **model_kwargs)

Bases: ACW

A probabilistic neural network: a neural network (ACW) whose output is specified by the given probabilistic model.

get_prob_model()