import keras
import tensorflow as tf
class Linear(keras.layers.Layer):
    """Fully connected layer that computes ``inputs @ w + b``.

    Args:
        input_dim: Size of the last axis of the inputs the layer multiplies
            against its kernel.
        output_dim: Number of output features produced per input row.

    Attributes:
        w: Trainable kernel of shape ``(input_dim, output_dim)``, initialized
            from a random normal distribution.
        b: Trainable bias of shape ``(output_dim,)``, initialized to zeros.
    """

    def __init__(self, input_dim: int = 32, output_dim: int = 32) -> None:
        super().__init__()
        # Create the parameters through `add_weight` instead of assigning
        # bare `tf.Variable` objects: this is the layer API for registering
        # weights, so they are reported by `weights` / `trainable_weights`.
        # Plain `tf.Variable` attributes are not reliably tracked across
        # Keras versions.
        self.w = self.add_weight(
            shape=(input_dim, output_dim),
            initializer="random_normal",
            dtype="float32",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(output_dim,),
            initializer="zeros",
            dtype="float32",
            trainable=True,
        )

    def call(self, inputs):
        """Apply the affine transformation to ``inputs``.

        Args:
            inputs: Tensor whose last axis has size ``input_dim``.

        Returns:
            The result of ``tf.matmul(inputs, self.w) + self.b``.
        """
        # Matrix product: an (m, n) input times the (n, p) kernel gives an
        # (m, p) result; the (p,) bias is then added to every row.
        return tf.matmul(inputs, self.w) + self.b
# Demo input: 3 rows with 2 features each, all ones.
x = tf.ones(shape=(3, 2))
linear_layer = Linear(input_dim=2, output_dim=5)

# Functional style: invoke the layer object directly, i.e. linear_layer(x).
y = linear_layer(x)
print(y.shape)

# Lower-level usage: invoke the layer's `call` method explicitly.
y = linear_layer.call(x)
print(y.shape)