diff --git a/src/differentiate.jl b/src/differentiate.jl index 16aba4c..4ceb603 100644 --- a/src/differentiate.jl +++ b/src/differentiate.jl @@ -9,14 +9,38 @@ export differentiate # ################################################################# -differentiate(ex::SymbolicVariable, wrt::SymbolicVariable) = (ex == wrt) ? 1 : 0 +macro makeDerivative(ex, wrt) + if haskey(storedFunctions, ex) + #in that case, ex is a function that was defined already + args = storedFunctions[ex].args + code = storedFunctions[ex].code + derivativeCode = differentiate(code, wrt) + res = eval(:($args -> $derivativeCode)) + return res + else + derivativeCode = differentiate(ex, wrt) + res = eval(:($wrt -> $derivativeCode)) + return res + end +end + +function differentiate(ex::SymbolicVariable, wrt::SymbolicVariable) + if haskey(storedFunctions, ex) + return differentiate(storedFunctions[ex].code, wrt) + end + (ex == wrt) ? 1 : 0 +end + differentiate(ex::Number, wrt::SymbolicVariable) = 0 function differentiate(ex::Expr,wrt) + #println("about to differentiate: ",ex," with respect to ", wrt) if ex.head != :call error("Unrecognized expression $ex") end + #println(SymbolParameter(ex.args[1])) + #println(ex.args[2:]) simplify(differentiate(SymbolParameter(ex.args[1]), ex.args[2:end], wrt)) end diff --git a/src/storingFunctions.jl b/src/storingFunctions.jl new file mode 100644 index 0000000..1ab943e --- /dev/null +++ b/src/storingFunctions.jl @@ -0,0 +1,34 @@ +export StoredFunction + + +global storedFunctions = Dict() + +type StoredFunction + args + code +end + +macro define(functionName, args, code) + storedFunctions[functionName] = StoredFunction(args,code) + :(global $functionName = $args -> $code) +end + + +function testStoredFunction(functionName::Expr) + print(storedFunctions) + return haskey(storedFunctions, functionName) +end + + +## function differentiate(ex::Expr,wrt) +## print("hello") +## print(haskey(storedFunctions, ex)) +## if haskey(storedFunctions, ex) +## print("differentiating a user-defined function") +## differentiate(storedFunctions[ex].code, wrt) +## end +## if ex.head != :call +## error("Unrecognized expression $ex") +## end +## simplify(differentiate(SymbolParameter(ex.args[1]), ex.args[2:end], wrt)) +## end