1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
| import os import visdom import numpy as np import time
class Visualizer(object):
    """Thin wrapper around a ``visdom.Visdom`` client.

    Adds a few convenience helpers (``plot``, ``plot_many``, ``img``,
    ``img_many``, ``log``). Any other attribute is forwarded to the
    underlying client, so the native visdom API remains reachable either as
    ``self.vis.function`` or directly as ``self.function``, e.g.::

        self.text('hello visdom')
        self.histogram(t.randn(1000))
        self.line(t.arange(0, 10), t.arange(1, 11))

    (``t`` in these examples presumably refers to ``torch``.)
    """

    def __init__(self, env="default", **kwargs):
        """Create a visdom client bound to environment ``env``.

        Extra keyword arguments are passed straight to ``visdom.Visdom``.
        """
        self.vis = visdom.Visdom(env=env, **kwargs)
        self.env = env
        # Per-window counter used as the x coordinate by ``plot``.
        self.index = {}
        # Accumulated HTML text shown by ``log``.
        self.log_text = ""

    def reinit(self, env="default", **kwargs):
        """Replace the visdom client with one built from a new configuration.

        ``self.index`` and ``self.log_text`` are not reset.
        Returns ``self`` so the call can be chained.
        """
        self.vis = visdom.Visdom(env=env, **kwargs)
        self.env = env
        return self

    def plot_many(self, d):
        """Plot several scalars at once.

        :param d: mapping of window name -> scalar value, e.g. ``{"loss": 0.11}``.
        """
        # ``items()`` works on both Python 2 and 3 (``iteritems`` is Py2-only).
        for k, v in d.items():
            self.plot(k, v)

    def img_many(self, d):
        """Show several images at once.

        :param d: mapping of window name -> image data accepted by ``img``.
        """
        for k, v in d.items():
            self.img(k, v)

    def plot(self, name, y, **kwargs):
        """Add scalar ``y`` as the next point of the line chart in window ``name``.

        The x coordinate is a per-window counter starting at 0. The first
        point for a window is sent with ``update=None``; subsequent points
        use ``update="append"``. Extra keyword arguments go to ``vis.line``.
        """
        x = self.index.get(name, 0)
        self.vis.line(
            Y=np.array([y]),
            X=np.array([x]),
            win=name,
            opts=dict(title=name),
            update=None if x == 0 else "append",
            **kwargs
        )
        self.index[name] = x + 1

    def img(self, name, img_, **kwargs):
        """Show ``img_`` in window ``name`` (title set to ``name``).

        Examples::

            self.img("input_img", t.Tensor(64, 64))
            self.img("input_imgs", t.Tensor(3, 64, 64))
            self.img("input_imgs", t.Tensor(100, 1, 64, 64))
            self.img("input_imgs", t.Tensor(100, 3, 64, 64), nrows=10)

        Extra keyword arguments go to ``vis.images``.
        """
        self.vis.images(img_,
                        win=name,
                        opts=dict(title=name),
                        **kwargs
                        )

    def log(self, info, win="log_text"):
        """Append a timestamped entry to the text window ``win``.

        Example: ``self.log({"loss": 1, "lr": 0.0001})``. ``info`` is rendered
        via ``str.format``. The whole accumulated log is re-sent on each call.
        """
        # NOTE(review): ``info`` is inserted into HTML unescaped; fine for
        # trusted training metrics, confirm if it can carry external text.
        self.log_text += ("[{time}] {info} <br>".format(
            time=time.strftime("%m%d_%H%M%S"),
            info=info))
        self.vis.text(self.log_text, win)

    def __getattr__(self, name):
        """Forward unknown attributes to the underlying visdom client.

        Only invoked when normal lookup fails, so the helpers defined on this
        class (plot, img, log, plot_many, ...) take precedence.
        """
        # If ``vis`` itself is missing (instance built without __init__, e.g.
        # during unpickling), ``self.vis`` would re-enter __getattr__ forever
        # and raise RecursionError instead of the AttributeError that
        # hasattr()/getattr(obj, name, default) rely on.
        if name == "vis":
            raise AttributeError(name)
        return getattr(self.vis, name)
|