|
35 | 35 | "cell_type": "code",
|
36 | 36 | "execution_count": 1,
|
37 | 37 | "metadata": {},
|
38 |
| - "outputs": [], |
| 38 | + "outputs": [ |
| 39 | + { |
| 40 | + "name": "stderr", |
| 41 | + "output_type": "stream", |
| 42 | + "text": [ |
| 43 | + "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", |
| 44 | + " warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n", |
| 45 | + "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: LabelField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", |
| 46 | + " warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n" |
| 47 | + ] |
| 48 | + } |
| 49 | + ], |
39 | 50 | "source": [
|
40 | 51 | "import torch\n",
|
41 | 52 | "from torchtext import data\n",
|
|
46 | 57 | "torch.manual_seed(SEED)\n",
|
47 | 58 | "torch.backends.cudnn.deterministic = True\n",
|
48 | 59 | "\n",
|
49 |
| - "TEXT = data.Field(tokenize = 'spacy', include_lengths = True)\n", |
| 60 | + "TEXT = data.Field(tokenize = 'spacy',\n", |
| 61 | + " tokenizer_language = 'en_core_web_sm',\n", |
| 62 | + " include_lengths = True)\n", |
| 63 | + "\n", |
50 | 64 | "LABEL = data.LabelField(dtype = torch.float)"
|
51 | 65 | ]
|
52 | 66 | },
|
|
61 | 75 | "cell_type": "code",
|
62 | 76 | "execution_count": 2,
|
63 | 77 | "metadata": {},
|
64 |
| - "outputs": [], |
| 78 | + "outputs": [ |
| 79 | + { |
| 80 | + "name": "stderr", |
| 81 | + "output_type": "stream", |
| 82 | + "text": [ |
| 83 | + "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", |
| 84 | + " warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n" |
| 85 | + ] |
| 86 | + } |
| 87 | + ], |
65 | 88 | "source": [
|
66 | 89 | "from torchtext import datasets\n",
|
67 | 90 | "\n",
|
|
133 | 156 | "cell_type": "code",
|
134 | 157 | "execution_count": 5,
|
135 | 158 | "metadata": {},
|
136 |
| - "outputs": [], |
| 159 | + "outputs": [ |
| 160 | + { |
| 161 | + "name": "stderr", |
| 162 | + "output_type": "stream", |
| 163 | + "text": [ |
| 164 | + "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", |
| 165 | + " warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n" |
| 166 | + ] |
| 167 | + } |
| 168 | + ], |
137 | 169 | "source": [
|
138 | 170 | "BATCH_SIZE = 64\n",
|
139 | 171 | "\n",
|
|
204 | 236 | "\n",
|
205 | 237 | "As we are passing the lengths of our sentences to be able to use packed padded sequences, we have to add a second argument, `text_lengths`, to `forward`. \n",
|
206 | 238 | "\n",
|
207 |
| - "Before we pass our embeddings to the RNN, we need to pack them, which we do with `nn.utils.rnn.packed_padded_sequence`. This will cause our RNN to only process the non-padded elements of our sequence. The RNN will then return `packed_output` (a packed sequence) as well as the `hidden` and `cell` states (both of which are tensors). Without packed padded sequences, `hidden` and `cell` are tensors from the last element in the sequence, which will most probably be a pad token, however when using packed padded sequences they are both from the last non-padded element in the sequence. \n", |
| 239 | + "Before we pass our embeddings to the RNN, we need to pack them, which we do with `nn.utils.rnn.packed_padded_sequence`. This will cause our RNN to only process the non-padded elements of our sequence. The RNN will then return `packed_output` (a packed sequence) as well as the `hidden` and `cell` states (both of which are tensors). Without packed padded sequences, `hidden` and `cell` are tensors from the last element in the sequence, which will most probably be a pad token, however when using packed padded sequences they are both from the last non-padded element in the sequence. Note that the `lengths` argument of `packed_padded_sequence` must be a CPU tensor so we explicitly make it one by using `.to('cpu')`.\n", |
208 | 240 | "\n",
|
209 | 241 | "We then unpack the output sequence, with `nn.utils.rnn.pad_packed_sequence`, to transform it from a packed sequence to a tensor. The elements of `output` from padding tokens will be zero tensors (tensors where every element is zero). Usually, we only have to unpack output if we are going to use it later on in the model. Although we aren't in this case, we still unpack the sequence just to show how it is done.\n",
|
210 | 242 | "\n",
|
|
246 | 278 | " #embedded = [sent len, batch size, emb dim]\n",
|
247 | 279 | " \n",
|
248 | 280 | " #pack sequence\n",
|
249 |
| - " packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths)\n", |
| 281 | + " # lengths need to be on CPU!\n", |
| 282 | + " packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths.to('cpu'))\n", |
250 | 283 | " \n",
|
251 | 284 | " packed_output, (hidden, cell) = self.rnn(packed_embedded)\n",
|
252 | 285 | " \n",
|
|
383 | 416 | " [-0.8555, -0.7208, 1.3755, ..., 0.0825, -1.1314, 0.3997],\n",
|
384 | 417 | " [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
|
385 | 418 | " ...,\n",
|
386 |
| - " [-0.0614, -0.0516, -0.6159, ..., -0.0354, 0.0379, -0.1809],\n", |
387 |
| - " [ 0.1885, -0.1690, 0.1530, ..., -0.2077, 0.5473, -0.4517],\n", |
388 |
| - " [-0.1182, -0.4701, -0.0600, ..., 0.7991, -0.0194, 0.4785]])" |
| 419 | + " [ 0.6783, 0.0488, 0.5860, ..., 0.2680, -0.0086, 0.5758],\n", |
| 420 | + " [-0.6208, -0.0480, -0.1046, ..., 0.3718, 0.1225, 0.1061],\n", |
| 421 | + " [-0.6553, -0.6292, 0.9967, ..., 0.2278, -0.1975, 0.0857]])" |
389 | 422 | ]
|
390 | 423 | },
|
391 | 424 | "execution_count": 10,
|
|
421 | 454 | " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
422 | 455 | " [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
|
423 | 456 | " ...,\n",
|
424 |
| - " [-0.0614, -0.0516, -0.6159, ..., -0.0354, 0.0379, -0.1809],\n", |
425 |
| - " [ 0.1885, -0.1690, 0.1530, ..., -0.2077, 0.5473, -0.4517],\n", |
426 |
| - " [-0.1182, -0.4701, -0.0600, ..., 0.7991, -0.0194, 0.4785]])\n" |
| 457 | + " [ 0.6783, 0.0488, 0.5860, ..., 0.2680, -0.0086, 0.5758],\n", |
| 458 | + " [-0.6208, -0.0480, -0.1046, ..., 0.3718, 0.1225, 0.1061],\n", |
| 459 | + " [-0.6553, -0.6292, 0.9967, ..., 0.2278, -0.1975, 0.0857]])\n" |
427 | 460 | ]
|
428 | 461 | }
|
429 | 462 | ],
|
|
638 | 671 | "execution_count": 18,
|
639 | 672 | "metadata": {},
|
640 | 673 | "outputs": [
|
| 674 | + { |
| 675 | + "name": "stderr", |
| 676 | + "output_type": "stream", |
| 677 | + "text": [ |
| 678 | + "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", |
| 679 | + " warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n" |
| 680 | + ] |
| 681 | + }, |
641 | 682 | {
|
642 | 683 | "name": "stdout",
|
643 | 684 | "output_type": "stream",
|
644 | 685 | "text": [
|
645 |
| - "Epoch: 01 | Epoch Time: 0m 28s\n", |
646 |
| - "\tTrain Loss: 0.648 | Train Acc: 62.05%\n", |
647 |
| - "\t Val. Loss: 0.620 | Val. Acc: 66.72%\n", |
648 |
| - "Epoch: 02 | Epoch Time: 0m 27s\n", |
649 |
| - "\tTrain Loss: 0.622 | Train Acc: 66.51%\n", |
650 |
| - "\t Val. Loss: 0.669 | Val. Acc: 62.83%\n", |
651 |
| - "Epoch: 03 | Epoch Time: 0m 27s\n", |
652 |
| - "\tTrain Loss: 0.586 | Train Acc: 69.01%\n", |
653 |
| - "\t Val. Loss: 0.522 | Val. Acc: 75.52%\n", |
654 |
| - "Epoch: 04 | Epoch Time: 0m 27s\n", |
655 |
| - "\tTrain Loss: 0.415 | Train Acc: 82.02%\n", |
656 |
| - "\t Val. Loss: 0.457 | Val. Acc: 77.10%\n", |
657 |
| - "Epoch: 05 | Epoch Time: 0m 27s\n", |
658 |
| - "\tTrain Loss: 0.335 | Train Acc: 86.15%\n", |
659 |
| - "\t Val. Loss: 0.305 | Val. Acc: 87.15%\n" |
| 686 | + "Epoch: 01 | Epoch Time: 0m 36s\n", |
| 687 | + "\tTrain Loss: 0.673 | Train Acc: 58.05%\n", |
| 688 | + "\t Val. Loss: 0.619 | Val. Acc: 64.97%\n", |
| 689 | + "Epoch: 02 | Epoch Time: 0m 36s\n", |
| 690 | + "\tTrain Loss: 0.611 | Train Acc: 66.33%\n", |
| 691 | + "\t Val. Loss: 0.510 | Val. Acc: 74.32%\n", |
| 692 | + "Epoch: 03 | Epoch Time: 0m 37s\n", |
| 693 | + "\tTrain Loss: 0.484 | Train Acc: 77.04%\n", |
| 694 | + "\t Val. Loss: 0.397 | Val. Acc: 82.95%\n", |
| 695 | + "Epoch: 04 | Epoch Time: 0m 37s\n", |
| 696 | + "\tTrain Loss: 0.384 | Train Acc: 83.57%\n", |
| 697 | + "\t Val. Loss: 0.407 | Val. Acc: 83.23%\n", |
| 698 | + "Epoch: 05 | Epoch Time: 0m 37s\n", |
| 699 | + "\tTrain Loss: 0.314 | Train Acc: 86.98%\n", |
| 700 | + "\t Val. Loss: 0.314 | Val. Acc: 86.36%\n" |
660 | 701 | ]
|
661 | 702 | }
|
662 | 703 | ],
|
|
701 | 742 | "name": "stdout",
|
702 | 743 | "output_type": "stream",
|
703 | 744 | "text": [
|
704 |
| - "Test Loss: 0.308 | Test Acc: 87.07%\n" |
| 745 | + "Test Loss: 0.334 | Test Acc: 85.28%\n" |
705 | 746 | ]
|
706 | 747 | }
|
707 | 748 | ],
|
|
744 | 785 | "outputs": [],
|
745 | 786 | "source": [
|
746 | 787 | "import spacy\n",
|
747 |
| - "nlp = spacy.load('en')\n", |
| 788 | + "nlp = spacy.load('en_core_web_sm')\n", |
748 | 789 | "\n",
|
749 | 790 | "def predict_sentiment(model, sentence):\n",
|
750 | 791 | " model.eval()\n",
|
|
773 | 814 | {
|
774 | 815 | "data": {
|
775 | 816 | "text/plain": [
|
776 |
| - "0.005683214403688908" |
| 817 | + "0.05380420759320259" |
777 | 818 | ]
|
778 | 819 | },
|
779 | 820 | "execution_count": 21,
|
|
800 | 841 | {
|
801 | 842 | "data": {
|
802 | 843 | "text/plain": [
|
803 |
| - "0.9926869869232178" |
| 844 | + "0.94941645860672" |
804 | 845 | ]
|
805 | 846 | },
|
806 | 847 | "execution_count": 22,
|
|
838 | 879 | "name": "python",
|
839 | 880 | "nbconvert_exporter": "python",
|
840 | 881 | "pygments_lexer": "ipython3",
|
841 |
| - "version": "3.7.0" |
| 882 | + "version": "3.8.5" |
842 | 883 | }
|
843 | 884 | },
|
844 | 885 | "nbformat": 4,
|
|
0 commit comments