# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle


def update_vgg16_params(model_path, save_path='VGG16_pretrained.pdparams'):
    """Rewrite a VGG16 checkpoint: pad the first conv and turn fc1 into a conv.

    The state dict stored at ``model_path`` is loaded and modified as follows:

    * ``_conv_block_1._conv_1.weight`` gets one extra, zero-filled input
      channel appended along axis 1 (``[64, 3, 3, 3]`` -> ``[64, 4, 3, 3]``).
    * ``_fc1.weight`` (``[25088, 4096]``) is transposed, reshaped to
      ``(4096, 512, 7, 7)`` and cropped to its first 512 output channels and
      the central 3x3 spatial window; the result is stored under
      ``_conv_6.weight``.
    * The ``_fc1``, ``_fc2`` and ``_out`` weights and biases are removed.

    The resulting state dict is written to ``save_path``.

    Args:
        model_path: Path of the source ``.pdparams`` file. A leading ``~`` is
            expanded to the user's home directory.
        save_path: Destination path for the converted parameters. A leading
            ``~`` is expanded as well. Defaults to
            ``'VGG16_pretrained.pdparams'`` in the current working directory.

    Raises:
        KeyError: If one of the expected parameter names is missing from the
            loaded state dict.
    """
    # paddle.load does not expand '~' itself, so resolve it explicitly.
    param_state_dict = paddle.load(os.path.expanduser(model_path))

    # first conv weight name: _conv_block_1._conv_1.weight, shape is [64, 3, 3, 3]
    # first fc weight name: _fc1.weight, shape is [25088, 4096]
    for k, v in param_state_dict.items():
        print(k, v.shape)

    # First conv: append a zero-initialised fourth input channel so that the
    # pretrained filters are unchanged for the original three channels.
    first_conv_key = '_conv_block_1._conv_1.weight'
    weight = param_state_dict[first_conv_key]  # [64, 3, 3, 3]
    print('ori shape: ', weight.shape)
    zeros_pad = paddle.zeros((64, 1, 3, 3))
    param_state_dict[first_conv_key] = paddle.concat(
        [weight, zeros_pad], axis=1)
    print('shape after padding', param_state_dict[first_conv_key].shape)

    # fc1: reinterpret the fully connected weight as a 7x7 convolution kernel,
    # then keep only the first 512 output channels and the central 3x3 taps.
    weight = param_state_dict['_fc1.weight']
    weight = paddle.transpose(weight, [1, 0])
    print('after transpose: ', weight.shape)
    weight = paddle.reshape(weight, (4096, 512, 7, 7))
    print('after reshape: ', weight.shape)
    weight = weight[0:512, :, 2:5, 2:5]
    print('after crop: ', weight.shape)
    param_state_dict['_conv_6.weight'] = weight

    # Drop the classifier head parameters from the converted checkpoint.
    for name in ('_fc1', '_fc2', '_out'):
        del param_state_dict[name + '.weight']
        del param_state_dict[name + '.bias']

    paddle.save(param_state_dict, os.path.expanduser(save_path))


if __name__ == "__main__":
    paddle.set_device('cpu')
    model_path = '~/.paddleseg/pretrained_model/dygraph/VGG16_pretrained.pdparams'
    update_vgg16_params(model_path)