PyTorch

Loading Models

Loading a .pth file

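# Load a checkpoint saved as a .pth file into `model`.
import torch

def load_model(model, file_name):
    def reformat_dict(state):
        # Strip the leading 'module.' that nn.DataParallel adds to parameter names.
        # Only the prefix is removed, so a key like 'submodule.weight' is left intact.
        reformat_state = {}
        for key in state:
            new_key = key[len('module.'):] if key.startswith('module.') else key
            reformat_state[new_key] = state[key]
        return reformat_state
    # The checkpoint is expected to hold the state dict under the 'net' key.
    state = torch.load(file_name, map_location='cpu')['net']
    reformat_state = reformat_dict(state)
    model.load_state_dict(reformat_state)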
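
To make the expected checkpoint layout concrete, here is a minimal round-trip sketch. It assumes the .pth file was written from a model wrapped in nn.DataParallel (hence the 'module.' prefix) with the state dict stored under a 'net' key, as the loader above implies. The nn.Linear model, the file name example.pth, and the helper strip_module_prefix are hypothetical stand-ins, not part of the original notes.

```python
import os
import tempfile

import torch
import torch.nn as nn


def strip_module_prefix(state):
    """Return a copy of `state` with a leading 'module.' removed from each key."""
    prefix = 'module.'
    return {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in state.items()}


# Hypothetical tiny model standing in for the real network.
source = nn.Linear(4, 2)

# Mimic a checkpoint written from an nn.DataParallel model: every key carries
# a 'module.' prefix and the state dict sits under the 'net' key.
checkpoint = {'net': {'module.' + k: v for k, v in source.state_dict().items()}}

path = os.path.join(tempfile.mkdtemp(), 'example.pth')
torch.save(checkpoint, path)

# Load on CPU, strip the prefix, and copy the weights into a fresh model.
target = nn.Linear(4, 2)
state = torch.load(path, map_location='cpu')['net']
target.load_state_dict(strip_module_prefix(state))

weights_match = torch.equal(source.weight, target.weight)
```

If the prefix were not stripped, load_state_dict would raise an error about unexpected and missing keys, because the plain model expects 'weight' rather than 'module.weight'.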

Saving Models

Saving a .pytorch_state file

torch.save(model.state_dict(), 'micro_large_conv_and_fc_ttq_x0_32.pytorch_state')
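
A file saved this way contains only the weights, so reading it back requires building the same architecture first. The sketch below shows one way to do that. make_model and the file name example.pytorch_state are hypothetical placeholders; the real model definition must match the one used when saving.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Hypothetical architecture; the real one must match the saved weights.
    return nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))


path = os.path.join(tempfile.mkdtemp(), 'example.pytorch_state')

# Save only the state dict (parameters and buffers), not the model object.
trained = make_model()
torch.save(trained.state_dict(), path)

# Rebuild the architecture, then load the saved weights onto the CPU.
restored = make_model()
restored.load_state_dict(torch.load(path, map_location='cpu'))
restored.eval()  # switch to inference mode before evaluating

# Both models should now produce the same output for the same input.
x = torch.randn(3, 8)
with torch.no_grad():
    same_output = torch.allclose(trained(x), restored(x))
```

No key renaming is needed here because the state dict was saved from an unwrapped model and is stored at the top level of the file.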

See the Google Colab Tools page.