Describes an individual function call as part of a computation.
#include <ad_computenode.h>
Public Member Functions
STensor | autodiff (ComputeNodePtr ptr_to_me, STensorId const input_id, STensorId const result_id, AutodiffRescale do_rescale=AutodiffRescale::n()) |
Does the heavy lifting to compute the total derivative of the calling tensor (identified by result_id) with respect to one of its input tensors (identified by input_id).
ComputeNode (std::string &&opname_, Vec< Pair< ComputeNodePtr, Size > > &&input_nodes_, Vec< STensorId > &&output_ids_, AdjointEvaluator &&func_, Vec< STensor > &&output_shapes_, Vec< AsyncCached< STensor > > &&cached_tensors_={}) | |
Standard ctor.
void | draw (std::ostream &out) const |
Writes to the ostream a dot-encoded representation of the callgraph.
Vec< AsyncCached< STensor > > const & | get_cached_tensors () const |
Returns a constant view of the cached tensors for the adjoint evaluator.
STensor | get_input_adjoint (Size const input_number, STensorId const result_id, AutodiffRescale do_rescale) |
Returns the adjoint of the specified input tensor with respect to the specified output ID from this subtree of the calculation.
STensor const & | get_output_adjoint (Size const output_number, STensorId const result_id) |
Returns the adjoint of a specific output tensor with respect to some result.
Vec< STensorId > const & | get_output_ids () const |
Returns a constant view of the STensorIds of the output tensors created in this compute node.
Size | get_output_number (STensorId const output_id) const |
Returns the number of an output tensor given its STensorId.
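
To make the adjoint terminology above concrete, here is a minimal sketch on plain scalars instead of tensors. Every name in it (ToyNode, ToyGraph, output_adjoint) is hypothetical and not part of ad_computenode.h. It assumes one output per node and ignores rescaling; it only illustrates how the adjoint of an output with respect to a result can be accumulated over all downstream consumers and memoised per result ID, loosely in the spirit of get_output_adjoint and adjoint_cache.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <map>
#include <utility>
#include <vector>

using ToyId = std::size_t; // hypothetical stand-in for a tensor ID

struct ToyNode {
    ToyId id;                                    // ID of the single output
    double value;                                // forward value
    std::vector<ToyNode*> inputs;                // non-owning upstream links
    std::vector<double> partials;                // d(value)/d(inputs[j])
    std::vector<std::pair<ToyNode*, std::size_t>> consumers; // (downstream node, its input slot)
    std::map<ToyId, double> adjoint_cache;       // keyed by result ID
};

struct ToyGraph {
    std::deque<ToyNode> nodes; // deque keeps element addresses stable on push_back

    ToyNode* leaf(double v) {
        nodes.push_back(ToyNode{nodes.size(), v, {}, {}, {}, {}});
        return &nodes.back();
    }
    ToyNode* mul(ToyNode* a, ToyNode* b) {
        return op(a->value * b->value, {a, b}, {b->value, a->value});
    }
    ToyNode* add(ToyNode* a, ToyNode* b) {
        return op(a->value + b->value, {a, b}, {1.0, 1.0});
    }

  private:
    ToyNode* op(double v, std::vector<ToyNode*> in, std::vector<double> d) {
        nodes.push_back(ToyNode{nodes.size(), v, std::move(in), std::move(d), {}, {}});
        ToyNode* n = &nodes.back();
        // Register this node as a consumer of each of its inputs.
        for (std::size_t j = 0; j < n->inputs.size(); ++j)
            n->inputs[j]->consumers.emplace_back(n, j);
        return n;
    }
};

// d(result)/d(node output): 1 for the result itself, otherwise the sum over
// all consumers of (consumer adjoint) * (consumer's local partial).
double output_adjoint(ToyNode* node, ToyId result_id) {
    if (node->id == result_id) return 1.0;
    auto hit = node->adjoint_cache.find(result_id);
    if (hit != node->adjoint_cache.end()) return hit->second;
    double total = 0.0;
    for (auto const& link : node->consumers) {
        assert(link.second < link.first->partials.size());
        total += output_adjoint(link.first, result_id) * link.first->partials[link.second];
    }
    node->adjoint_cache[result_id] = total;
    return total;
}
```

For f = x*y + x, the value x reaches f along two paths, so its adjoint is the sum of both contributions; this is the "total derivative" aspect mentioned for autodiff.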
Private Member Functions
void | build_links_upstream (ComputeNodePtr ptr_to_me) |
Double-links the callgraph by walking upstream and setting the output_nodes maps of the nodes upstream from this node.
void | draw_recursive (std::ostream &out, std::set< ComputeNode const * > &seen_ptrs) const |
Recursive helper for ComputeNode::draw().
ComputeNodePtr | find_origin_node (ComputeNodePtr ptr_to_me, STensorId const input_id) |
Returns the origin node upstream of this node which generated the tensor with ID input_id.
std::string | full_name (bool newlines=false) const |
Returns a full name description enclosed in double quotes ("), suitable for feeding into dot.
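
The draw() / draw_recursive() pair described above can be pictured with the following sketch. It is not the library's implementation: ToyNode, toy_draw and toy_draw_recursive are hypothetical names, node labels are assumed to contain no characters that need escaping, and the exact dot layout the real class emits is not known from this page. What it does show is the role of the set of already-seen pointers: a node reachable along several paths is declared only once.

```cpp
#include <cassert>
#include <cstddef>
#include <ostream>
#include <set>
#include <sstream>
#include <string>
#include <vector>

struct ToyNode {
    std::size_t id;                     // used only to build a stable dot name
    std::string opname;                 // label shown in the graph
    std::vector<ToyNode const*> inputs; // non-owning upstream links
};

// Emits `node` and, recursively, everything upstream of it.
// `seen` guarantees that each node is declared at most once.
void toy_draw_recursive(ToyNode const& node, std::ostream& out,
                        std::set<ToyNode const*>& seen) {
    if (!seen.insert(&node).second) return; // already drawn
    out << "  n" << node.id << " [label=\"" << node.opname << "\"];\n";
    for (ToyNode const* in : node.inputs) {
        assert(in != nullptr);
        out << "  n" << in->id << " -> n" << node.id << ";\n";
        toy_draw_recursive(*in, out, seen);
    }
}

// Public entry point: wraps the recursive walk in a digraph block.
void toy_draw(ToyNode const& root, std::ostream& out) {
    std::set<ToyNode const*> seen;
    out << "digraph callgraph {\n";
    toy_draw_recursive(root, out, seen);
    out << "}\n";
}

// Convenience wrapper for inspecting the output as a string.
std::string toy_draw_to_string(ToyNode const& root) {
    std::ostringstream out;
    toy_draw(root, out);
    return out.str();
}
```

The resulting text can be piped into the Graphviz dot tool to render the graph.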
Private Attributes
std::map< Pair< Size, STensorId >, AsyncCached< STensor > > | adjoint_cache |
Cache of adjoints of output tensors created by this compute node.
Bool | built_links_upstream = false |
Set to true by the first call to build_links_upstream.
Vec< AsyncCached< STensor > > | cached_tensors |
Vector of tensors used by the adjoint evaluator and cached here from the evaluation of this compute node.
AdjointEvaluator | func |
Function which, when passed a reference to this node, an index j (input_number) into this ComputeNode's input nodes, and a result tensor ID R, builds the adjoint of input j with respect to R.
Vec< Pair< ComputeNodePtr, Size > > | input_nodes |
Stores the compute nodes which provided input to this compute node, each paired with an index into that compute node's vector of output tensor IDs.
AutodiffRescale | node_do_rescale = AutodiffRescale::n() |
Temporarily sets this node to rescale its outputs to avoid accumulation of overly large numbers.
std::string | opname |
Stores a string description of the operation represented by this node.
Vec< STensorId > | output_ids |
IDs of the output tensors created during this compute node's evaluation.
std::map< ComputeNodeWPtr, std::set< Size >, std::owner_less< ComputeNodeWPtr > > | output_nodes |
Once a backward sweep starts, this map collects the compute nodes which received one of the outputs of this compute node as input.
Vec< STensor > | output_shapes |
Empty tensors of the shapes of the output tensors.
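
The interplay of input_nodes, output_nodes and built_links_upstream can be sketched as follows. This is an illustration under assumptions, not the library's code: ToyNode and toy_build_links_upstream are hypothetical, and the meaning given here to the std::set<Size> values (the downstream node's input slots that read from this node) is a guess, since this page does not state it. The sketch shows why weak pointers with std::owner_less are a natural fit: upstream links own their targets via shared pointers, while the downstream back-links are non-owning, so the double-linked graph contains no ownership cycle.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <set>
#include <utility>
#include <vector>

struct ToyNode;
using ToyNodePtr  = std::shared_ptr<ToyNode>;
using ToyNodeWPtr = std::weak_ptr<ToyNode>;

struct ToyNode {
    // (upstream node, index into that node's outputs); owning links.
    std::vector<std::pair<ToyNodePtr, std::size_t>> input_nodes;
    // downstream node -> input slots of that node fed by this node (assumed meaning).
    // weak_ptr has no operator<, so std::owner_less provides the ordering.
    std::map<ToyNodeWPtr, std::set<std::size_t>, std::owner_less<ToyNodeWPtr>> output_nodes;
    bool built_links_upstream = false;
};

// Walks upstream from `me` and records `me` in the output_nodes map of every
// node that feeds it. The flag makes repeated calls a no-op.
void toy_build_links_upstream(ToyNodePtr const& me) {
    assert(me != nullptr);
    if (me->built_links_upstream) return;
    me->built_links_upstream = true;
    for (std::size_t slot = 0; slot < me->input_nodes.size(); ++slot) {
        ToyNodePtr const& up = me->input_nodes[slot].first;
        up->output_nodes[me].insert(slot); // shared_ptr converts to weak_ptr key
        toy_build_links_upstream(up);
    }
}
```

A node that consumes the same upstream tensor twice ends up as a single key in the upstream map with two slots in its set.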