Java example source code file (ParameterAveragingTrainingWorkerStats.java)
The ParameterAveragingTrainingWorkerStats.java Java example source codepackage org.deeplearning4j.spark.impl.paramavg.stats; import lombok.Data; import org.deeplearning4j.spark.api.stats.SparkTrainingStats; import org.nd4j.linalg.util.ArrayUtil; import java.util.*; /** * Statistics colected by {@link org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingWorker} instances * * @author Alex Black */ @Data public class ParameterAveragingTrainingWorkerStats implements SparkTrainingStats { private int[] parameterAveragingWorkerBroadcastGetValueTimeMs; private int[] parameterAveragingWorkerInitTimeMs; private int[] parameterAveragingWorkerFitTimesMs; private static Set<String> columnNames = Collections.unmodifiableSet( new LinkedHashSet<>(Arrays.asList( "ParameterAveragingWorkerBroadcastGetValueTimeMs", "ParameterAveragingWorkerInitTimeMs", "ParameterAveragingWorkerFitTimesMs" ))); public ParameterAveragingTrainingWorkerStats(int parameterAveragingWorkerBroadcastGetValueTimeMs, int parameterAveragingWorkerInitTimeMs, int[] parameterAveragingWorkerFitTimesMs){ this.parameterAveragingWorkerBroadcastGetValueTimeMs = new int[]{parameterAveragingWorkerBroadcastGetValueTimeMs}; this.parameterAveragingWorkerInitTimeMs = new int[]{parameterAveragingWorkerInitTimeMs}; this.parameterAveragingWorkerFitTimesMs = parameterAveragingWorkerFitTimesMs; } @Override public Set<String> getKeySet() { return columnNames; } @Override public Object getValue(String key) { switch(key){ case "ParameterAveragingWorkerBroadcastGetValueTimeMs": return parameterAveragingWorkerBroadcastGetValueTimeMs; case "ParameterAveragingWorkerInitTimeMs": return parameterAveragingWorkerInitTimeMs; case "ParameterAveragingWorkerFitTimesMs": return parameterAveragingWorkerFitTimesMs; default: throw new IllegalArgumentException("Unknown key: \"" + key + "\""); } } @Override public void addOtherTrainingStats(SparkTrainingStats other) { if(!(other instanceof ParameterAveragingTrainingWorkerStats)) throw new 
IllegalArgumentException("Cannot merge ParameterAveragingTrainingWorkerStats with " + (other != null ? other.getClass() : null)); ParameterAveragingTrainingWorkerStats o = (ParameterAveragingTrainingWorkerStats)other; this.parameterAveragingWorkerBroadcastGetValueTimeMs = ArrayUtil.combine(parameterAveragingWorkerBroadcastGetValueTimeMs,o.parameterAveragingWorkerBroadcastGetValueTimeMs); this.parameterAveragingWorkerInitTimeMs = ArrayUtil.combine(parameterAveragingWorkerInitTimeMs, o.parameterAveragingWorkerInitTimeMs); this.parameterAveragingWorkerFitTimesMs = ArrayUtil.combine(parameterAveragingWorkerFitTimesMs, o.parameterAveragingWorkerFitTimesMs); } @Override public SparkTrainingStats getNestedTrainingStats(){ return null; } @Override public String statsAsString() { StringBuilder sb = new StringBuilder(); String f = SparkTrainingStats.DEFAULT_PRINT_FORMAT; sb.append(String.format(f,"ParameterAveragingWorkerBroadcastGetValueTimeMs")); if(parameterAveragingWorkerBroadcastGetValueTimeMs == null ) sb.append("-\n"); else sb.append(Arrays.toString(parameterAveragingWorkerBroadcastGetValueTimeMs)).append("\n"); sb.append(String.format(f,"ParameterAveragingWorkerInitTimeMs")); if(parameterAveragingWorkerInitTimeMs == null ) sb.append("-\n"); else sb.append(Arrays.toString(parameterAveragingWorkerInitTimeMs)).append("\n"); sb.append(String.format(f,"ParameterAveragingWorkerFitTimesMs")); if(parameterAveragingWorkerFitTimesMs == null ) sb.append("-\n"); else sb.append(Arrays.toString(parameterAveragingWorkerFitTimesMs)).append("\n"); return sb.toString(); } public static class ParameterAveragingTrainingWorkerStatsHelper { private long broadcastStartTime; private long broadcastEndTime; private long initEndTime; private long lastFitStartTime; //TODO replace with fast int collection (no boxing) private List<Integer> fitTimes = new ArrayList<>(); public void logBroadcastGetValueStart(){ broadcastStartTime = System.currentTimeMillis(); } public void 
logBroadcastGetValueEnd(){ broadcastEndTime = System.currentTimeMillis(); } public void logInitEnd(){ initEndTime = System.currentTimeMillis(); } public void logFitStart(){ lastFitStartTime = System.currentTimeMillis(); } public void logFitEnd(){ long now = System.currentTimeMillis(); fitTimes.add((int)(now - lastFitStartTime)); } public ParameterAveragingTrainingWorkerStats build(){ int bcast = (int)(broadcastEndTime - broadcastStartTime); int init = (int)(initEndTime - broadcastEndTime); //Init starts at same time that broadcast ends int[] fitTimesArr = new int[fitTimes.size()]; for( int i=0; i<fitTimesArr.length; i++ ) fitTimesArr[i] = fitTimes.get(i); return new ParameterAveragingTrainingWorkerStats(bcast, init, fitTimesArr); } } } Other Java examples (source code examples)Here is a short list of links related to this Java ParameterAveragingTrainingWorkerStats.java source code file: |
This post is sponsored by my books.
#1 New Release!
FP Best Seller
Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.
A percentage of advertising revenue from pages under the /java/jwarehouse URI on this website is paid back to open source projects.