-
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathSchNetConfig.js
109 lines (103 loc) · 3.05 KB
/
SchNetConfig.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import React from 'react'
import { TextField } from '@mui/material'
import PropTypes from 'prop-types'
import { camelToNaturalString } from '../../../utils'
/**
 * Size-like SchNet parameters the user may configure, keyed by parameter name.
 *
 * - `min`: smallest value the corresponding input accepts.
 * - `explanation`: text handed to the hover callback when the field is hovered.
 *
 * Entry order matters: the component below keeps per-field state in arrays
 * indexed by each entry's position in this object.
 * @type {Object<string, {min: number, explanation: string}>}
 */
const settableSizes = {
  depth: { min: 1, explanation: 'This defines how many layers your net will have' },
  embeddingDimension: { min: 1, explanation: 'This defines how many values each node corresponds to' },
  readoutSize: {
    min: 1,
    explanation: 'Number of nodes in the regressional multilayer perceptron part of the network',
  },
}
/**
* Configuration of SchNet-specific parameters
* @param schnetParams initial values
* @param updateFunc callback to update a parameter
* @param errorSignal callback whether current configuration is invalid
* @param hoverFunc callback for hovering
* @param leaveFunc callback for mouse pointer leaving component
* @returns {JSX.Element}
*/
export default function SchNetConfig({
schnetParams,
updateFunc,
errorSignal,
hoverFunc,
leaveFunc,
}) {
const [sizes, setSizes] = React.useState([
schnetParams.depth,
schnetParams.embeddingDimension,
schnetParams.readoutSize,
])
const [sizesError, setSizesError] = React.useState([false, false, false])
React.useEffect(() => {
errorSignal(sizesError.includes(true))
}, [sizesError])
/**
* called when a value is changed
* @param event the event which triggered
* @param i index of configured parameter in sizes
* @param key of parameter
* @param min lower bound for permitted values of the field
*/
const handleChange = (event, i, key, min) => {
const sizesErrorClone = [...sizesError]
sizesErrorClone[i] = event.target.value < min
setSizesError(sizesErrorClone)
if (event.target.value >= min) {
const sizesClone = [...sizes]
sizesClone[i] = event.target.value
updateFunc(key, event.target.value)
setSizes(sizesClone)
}
}
return (
<div>
{Object.entries(settableSizes).map(([key, value], i) => {
return (
<TextField
required
key={i}
id="outlined-number"
label={camelToNaturalString(key)}
type="number"
defaultValue={sizes[i]}
error={sizesError[i]}
helperText={sizesError[i] ? 'Must be a number above zero!' : ''}
onChange={(e) => handleChange(e, i, key, value.min)}
onMouseOver={(e) => {
hoverFunc(e, value.explanation)
}}
onMouseLeave={leaveFunc}
InputLabelProps={{
shrink: true,
}}
sx={{
m: 2,
}}
/>
)
})}
</div>
)
}
SchNetConfig.propTypes = {
schnetParams: PropTypes.object.isRequired,
updateFunc: PropTypes.func.isRequired,
errorSignal: PropTypes.func,
hoverFunc: PropTypes.func,
leaveFunc: PropTypes.func,
}