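The RNN hidden state

```py
import torch

class rnn_(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size):
        super().__init__()
        # Store the sizes that forward() needs to build the initial hidden state
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = torch.nn.RNN(input_size, hidden_size, num_layers, batch_first=True)

    def forward(self, x):
        # h0 has shape (num_layers, batch, hidden_size), even with batch_first=True
        h = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        out, h = self.rnn(x, h)
        return out
```

Create the initial hidden state inside `forward`, as above. Writing it anywhere else often produces baffling errors! A likely cause is that the real batch size, `x.size(0)`, is only known inside `forward`, and the last batch of an epoch can be smaller than the others.

*****

Writing a dataset by hand

Build the input and target tensors in `__init__`, and only index into them in `__getitem__`. The data has shape **(total, other)**, where "other" is, for example, (channel, height, width) for images, (28, 28) for MNIST, (sequence length) for text, and so on.

```py
import numpy as np
import torch
from torch.utils import data

class qohdataset(data.Dataset):
    """
    Dataset must define __getitem__ and __len__
    """
    def __init__(self, qoh):
        def padding(ele, num):
            # Append all-zero rows of width 47 until ele has num rows
            difference = num - len(ele)
            for _ in range(difference):
                ele.append(np.zeros((47,)))

        self.qoh = qoh
        for i in self.qoh:
            if len(i) < 13:
                padding(i, 13)
        self.qoh = np.array(self.qoh, dtype=int)
        print(self.qoh.shape)
        self.qoh = torch.from_numpy(self.qoh)
        self.seq = self.qoh[:, 0:12, :]   # inputs: steps 0..11
        self.tar = self.qoh[:, 1:13, :]   # targets: steps 1..12 (shifted by one)

    def __getitem__(self, index):
        """
        Returns one data pair (x and y) at position index; both are tensors.
        """
        x = self.seq[index, ...]
        y = self.tar[index, ...]
        return x, y

    def __len__(self):
        # Valid indices satisfy 0 <= index < lens; samples lie along dim 0
        lens = self.qoh.shape[0]
        return lens
```

*****

Notes on the dataloader

A DataLoader yields batches of shape (batch, other), where "other" matches the dataset. In general, you only need to define a `collate_fn` when the input sequences have different lengths; otherwise just use the DataLoader directly.

*****

Data types matter: float (float32), double (float64), half (float16), short (int16), int (int32), long (int64).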
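
The two points above (a `collate_fn` for sequences of unequal length, and getting the dtype right) can be sketched together. This is only a minimal illustration: `pad_collate` and the toy `samples` list are hypothetical names that do not appear in the notes, and the feature width of 47 is borrowed from the dataset example. It assumes the inputs should end up as float32 (which `torch.nn.RNN` expects with default parameters) and the labels as int64.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def pad_collate(batch):
    """Hypothetical collate_fn: pad variable-length (seq, label) pairs into one batch."""
    seqs, labels = zip(*batch)
    # Keep the true lengths so padding can be ignored later if needed
    lengths = torch.tensor([len(s) for s in seqs], dtype=torch.long)
    # pad_sequence zero-pads every sequence to the longest one in the batch;
    # .float() casts int64 data to float32
    padded = pad_sequence(list(seqs), batch_first=True).float()
    # Class-index labels are kept as long (int64)
    labels = torch.tensor(labels, dtype=torch.long)
    return padded, labels, lengths


# Toy data (an assumption for illustration): three int64 sequences of
# different lengths, each step being a vector of width 47
samples = [(torch.ones(n, 47, dtype=torch.int64), n % 2) for n in (5, 9, 13)]
loader = DataLoader(samples, batch_size=3, collate_fn=pad_collate)

x, y, lengths = next(iter(loader))
print(x.shape, x.dtype)           # (batch, longest sequence, 47), float32
print(y.dtype, lengths.tolist())  # int64 labels and the original lengths
```

Without the `collate_fn`, the default collation would try to stack tensors of different lengths and fail. When every sample already has the same shape, as in the padded dataset above, the default is enough and only the dtype cast is needed.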