SyTen
syten::STensorImpl::Autodiff::ComputeNode Class Reference

Describes an individual function call as part of a computation.

#include <ad_computenode.h>


Public Member Functions

STensor autodiff (ComputeNodePtr ptr_to_me, STensorId const input_id, STensorId const result_id, AutodiffRescale do_rescale=AutodiffRescale::n())
 Does the heavy lifting to compute the total derivative of the calling tensor (identified by result_id) with respect to one of its input tensors (identified by input_id).
 
 ComputeNode (std::string &&opname_, Vec< Pair< ComputeNodePtr, Size > > &&input_nodes_, Vec< STensorId > &&output_ids_, AdjointEvaluator &&func_, Vec< STensor > &&output_shapes_, Vec< AsyncCached< STensor > > &&cached_tensors_={})
 Standard ctor.
 
void draw (std::ostream &out) const
 Writes a dot-encoded representation of the callgraph to the ostream.
 
Vec< AsyncCached< STensor > > const & get_cached_tensors () const
 Returns a constant view of the cached tensors for the adjoint evaluator.
 
STensor get_input_adjoint (Size const input_number, STensorId const result_id, AutodiffRescale do_rescale)
 Returns the adjoint of the specified input tensor with respect to the specified output ID from this subtree of the calculation.
 
STensor const & get_output_adjoint (Size const output_number, STensorId const result_id)
 Returns the adjoint of a specific output tensor with respect to some result.
 
Vec< STensorId > const & get_output_ids () const
 Returns a constant view of the STensorIds of the output tensors created in this compute node.
 
Size get_output_number (STensorId const output_id) const
 Returns the number (position in this node's output list) of an output tensor given its STensorId.
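
The ID-to-number lookup described for get_output_number() can be pictured as a linear search through the vector of output IDs. The following is only a minimal sketch under stated assumptions: `Size` and `STensorId` are modelled as plain integers, `output_number_of` is a hypothetical free function rather than the SyTen member, and throwing on an unknown ID is an assumption, not documented behaviour.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for the SyTen types; the real definitions may differ.
using Size = std::size_t;
using STensorId = std::size_t;

// Returns the position of output_id within output_ids.
// Throwing on a miss is an assumption made for this sketch only.
Size output_number_of(std::vector<STensorId> const& output_ids, STensorId const output_id) {
    auto const it = std::find(output_ids.begin(), output_ids.end(), output_id);
    if (it == output_ids.end()) {
        throw std::out_of_range("tensor ID was not created by this node");
    }
    return static_cast<Size>(std::distance(output_ids.begin(), it));
}
```

With output IDs {17, 42, 99}, looking up 99 in this sketch yields position 2, which is the kind of index that get_output_adjoint() takes as output_number.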
 

Private Member Functions

void build_links_upstream (ComputeNodePtr ptr_to_me)
 Double-links the callgraph by walking upstream and setting the output_nodes maps of the nodes upstream of this node.
 
void draw_recursive (std::ostream &out, std::set< ComputeNode const * > &seen_ptrs) const
 Recursive helper for ComputeNode::draw().
 
ComputeNodePtr find_origin_node (ComputeNodePtr ptr_to_me, STensorId const input_id)
 Returns the origin node upstream of this node which generated the tensor with ID input_id.
 
std::string full_name (bool newlines=false) const
 Returns a full name description enclosed in double quotes, suitable for feeding into dot.
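
The split between draw() and draw_recursive() with a set of already-seen node pointers can be illustrated on a toy graph. This is a sketch under assumptions, not the SyTen implementation: `ToyNode` is a hypothetical type carrying only an operation name and its inputs, node identifiers are taken from addresses, and the exact dot output of the real class is not reproduced.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <ostream>
#include <set>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical toy node; only the fields needed for drawing are modelled.
struct ToyNode {
    std::string opname;
    std::vector<std::shared_ptr<ToyNode>> inputs;
};

// Emits this node and everything upstream of it, visiting each node once.
void draw_recursive(ToyNode const& node, std::ostream& out, std::set<ToyNode const*>& seen) {
    if (!seen.insert(&node).second) {
        return;  // already emitted via another path through the graph
    }
    // Quoted addresses serve as dot node IDs (an assumption of this sketch).
    out << "  \"" << static_cast<void const*>(&node) << "\" [label=\"" << node.opname << "\"];\n";
    for (auto const& in : node.inputs) {
        out << "  \"" << static_cast<void const*>(in.get()) << "\" -> \""
            << static_cast<void const*>(&node) << "\";\n";
        draw_recursive(*in, out, seen);
    }
}

// Writes a complete dot digraph for the callgraph ending in root.
void draw(ToyNode const& root, std::ostream& out) {
    std::set<ToyNode const*> seen;
    out << "digraph callgraph {\n";
    draw_recursive(root, out, seen);
    out << "}\n";
}
```

The seen set matters as soon as one tensor feeds several operations: without it, a shared upstream node would be emitted once per path that reaches it.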
 

Private Attributes

std::map< Pair< Size, STensorId >, AsyncCached< STensor > > adjoint_cache
 Cache of adjoints of output tensors created by this compute node.
 
Bool built_links_upstream = false
 Set to true by the first call to build_links_upstream.
 
Vec< AsyncCached< STensor > > cached_tensors
 Vector of tensors used by the adjoint evaluator and cached here from the evaluation of this compute node.
 
AdjointEvaluator func
 Function which is passed a reference to this node, an index j into this ComputeNode's input nodes (input_number) and a result tensor ID R, and builds the adjoint with respect to R.
 
Vec< Pair< ComputeNodePtr, Size > > input_nodes
 Stores the compute nodes which provided input to this compute node, each with an index into that compute node's vector of output tensor IDs.
 
AutodiffRescale node_do_rescale = AutodiffRescale::n()
 Temporarily sets this node to rescale its outputs to avoid accumulation of overly large numbers.
 
std::string opname
 Stores a string description of the operation represented by this node.
 
Vec< STensorId > output_ids
 IDs of the output tensors created during this compute node evaluation.
 
std::map< ComputeNodeWPtr, std::set< Size >, std::owner_less< ComputeNodeWPtr > > output_nodes
 Once a backward sweep starts, this map collects the compute nodes which received one of this compute node's outputs as input.
 
Vec< STensor > output_shapes
 Empty tensors of the shapes of the output tensors.
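
How input_nodes, output_nodes and built_links_upstream might interact can be sketched with standard-library types alone. The code below is a hedged illustration, not the SyTen source: `ToyNode` and the free function `build_links_upstream` are hypothetical, C++17 is assumed, and the recursion order is a guess. Only the container shapes (a vector of node/index pairs, and a map keyed by weak pointers ordered with std::owner_less) follow the attribute declarations above.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <set>
#include <utility>
#include <vector>

struct ToyNode;
using ToyNodePtr = std::shared_ptr<ToyNode>;
using ToyNodeWPtr = std::weak_ptr<ToyNode>;
using Size = std::size_t;

// Hypothetical node carrying only the linking-related attributes.
struct ToyNode {
    // (producer node, index into that producer's outputs), like input_nodes.
    std::vector<std::pair<ToyNodePtr, Size>> input_nodes;
    // consumer node -> which of our outputs it consumed, like output_nodes.
    std::map<ToyNodeWPtr, std::set<Size>, std::owner_less<ToyNodeWPtr>> output_nodes;
    bool built_links_upstream = false;
};

// Walks upstream from `me` and registers each consumer in its producers' maps.
// The flag ensures every node is processed only once.
void build_links_upstream(ToyNodePtr const& me) {
    if (me->built_links_upstream) {
        return;
    }
    me->built_links_upstream = true;
    for (auto const& [producer, output_index] : me->input_nodes) {
        producer->output_nodes[me].insert(output_index);
        build_links_upstream(producer);
    }
}
```

Keying the downstream map by weak pointers means producers do not own their consumers, so the forward (owning) links and the backward links do not form a shared_ptr reference cycle. std::owner_less supplies the ordering, because std::weak_ptr has no operator< of its own.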
 

Detailed Description

Describes an individual function call as part of a computation.
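
To make the idea of "a function call recorded as a node with an adjoint evaluator" concrete, here is a minimal scalar sketch of reverse-mode differentiation. Everything in it is an assumption for illustration: `ScalarNode`, `leaf`, `multiply`, `add` and `derivative` are hypothetical names, scalars stand in for tensors, and unlike the class documented here the sketch neither caches adjoints nor rescales them.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <vector>

struct ScalarNode;
using ScalarNodePtr = std::shared_ptr<ScalarNode>;

// One recorded function call on scalars (a stand-in for tensors).
struct ScalarNode {
    double value = 0.0;
    std::vector<ScalarNodePtr> inputs;
    // Given input index j and the adjoint of this node's output,
    // returns the contribution to the adjoint of input j.
    std::function<double(std::size_t, double)> adjoint_evaluator;
};

ScalarNodePtr leaf(double v) {
    auto n = std::make_shared<ScalarNode>();
    n->value = v;
    return n;
}

ScalarNodePtr multiply(ScalarNodePtr a, ScalarNodePtr b) {
    auto n = std::make_shared<ScalarNode>();
    n->value = a->value * b->value;
    double const av = a->value, bv = b->value;  // cached for the backward sweep
    n->inputs = {a, b};
    n->adjoint_evaluator = [av, bv](std::size_t j, double out_adj) {
        return out_adj * (j == 0 ? bv : av);
    };
    return n;
}

ScalarNodePtr add(ScalarNodePtr a, ScalarNodePtr b) {
    auto n = std::make_shared<ScalarNode>();
    n->value = a->value + b->value;
    n->inputs = {a, b};
    n->adjoint_evaluator = [](std::size_t, double out_adj) { return out_adj; };
    return n;
}

// Sums adjoint contributions over every path from node down to wrt.
double accumulate(ScalarNode const& node, ScalarNode const* wrt, double adj) {
    if (&node == wrt) {
        return adj;
    }
    double total = 0.0;
    for (std::size_t j = 0; j < node.inputs.size(); ++j) {
        total += accumulate(*node.inputs[j], wrt, node.adjoint_evaluator(j, adj));
    }
    return total;
}

// Total derivative of result with respect to wrt.
double derivative(ScalarNodePtr const& result, ScalarNodePtr const& wrt) {
    return accumulate(*result, wrt.get(), 1.0);
}
```

For w = x*y + x with x = 3 and y = 4, this sketch gives dw/dx = y + 1 = 5 and dw/dy = x = 3. The values captured inside multiply() play a role loosely analogous to cached_tensors: data from the forward evaluation kept for the adjoint evaluator.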


The documentation for this class was generated from the following files: