Sisällön yleiskatsaus asennettu Säästäminen tf.keras koulutus API: stä Kirjoittaminen Checkpoints Manuaalinen tarkistus Mekaaninen kuorma Viivästyneet kunnostukset Manuaalinen tarkastus tarkastuspisteissä Objektin seuranta Yhteenveto Lause "TensorFlow-mallin tallentaminen" tarkoittaa yleensä yhtä kahdesta asiasta: tarkastuspisteitä tai Pelastettu malli Tarkennuspisteet tallentavat tarkan arvon kaikista parametreista ( Tarkistuspisteet eivät sisällä mitään kuvausta mallin määrittelemästä laskennasta, joten ne ovat yleensä hyödyllisiä vain silloin, kun käytettävissä on lähdekoodi, joka käyttää tallennettuja parametriarvoja. tf.Variable SavedModel-muoto puolestaan sisältää mallin määrittelemän laskennan sarjakuvan parametrien arvojen lisäksi (checkpoint). Mallit tässä muodossa ovat riippumattomia mallin luomasta lähdekoodista. Ne soveltuvat siten käyttöön TensorFlow Servingin, TensorFlow Lite:n, TensorFlow.js:n tai muiden ohjelmointikielten ohjelmien kautta (C, C++, Java, Go, Rust, C# jne. TensorFlow API). Tämä opas kattaa APIs kirjoittamiseen ja lukemiseen tarkistuspisteitä. asennettu import tensorflow as tf class Net(tf.keras.Model): """A simple linear model.""" def __init__(self): super(Net, self).__init__() self.l1 = tf.keras.layers.Dense(5) def call(self, x): return self.l1(x) net = Net() Säästämällä APIs koulutus TF.Keräjä TF.Keräjä Katso myös Ohjeita säästämiseen ja palauttamiseen. tf.keras Tallenna TensorFlow tarkistuspiste. tf.keras.Model.save_weights net.save_weights('easy_checkpoint') Kirjoittaminen Checkpoints TensorFlow-mallin pysyvä tila tallennetaan Nämä voidaan rakentaa suoraan, mutta ne luodaan usein korkean tason API: n kautta, kuten tai . tf.Variable tf.keras.layers tf.keras.Model Helpoin tapa hallita muuttujia on liittää ne Python-objekteihin ja sitten viitata niihin. Alaluokkaiset ja ja Seuraava esimerkki rakentaa yksinkertaisen lineaarisen mallin ja kirjoittaa sitten tarkistuspisteet, jotka sisältävät arvot kaikista mallin muuttujista. tf.train.Checkpoint tf.keras.layers.Layer tf.keras.Model Voit tallentaa mallin tarkastuspisteen helposti . Model.save_weights Manuaalinen tarkistus asennettu Auttaa osoittamaan kaikki ominaisuudet , määrittele lelun tietokokonaisuus ja optimointivaihe: tf.train.Checkpoint def toy_dataset(): inputs = tf.range(10.)[:, None] labels = inputs * 5. + tf.range(5.)[None, :] return tf.data.Dataset.from_tensor_slices( dict(x=inputs, y=labels)).repeat().batch(2) def train_step(net, example, optimizer): """Trains `net` on `example` using `optimizer`.""" with tf.GradientTape() as tape: output = net(example['x']) loss = tf.reduce_mean(tf.abs(output - example['y'])) variables = net.trainable_variables gradients = tape.gradient(loss, variables) optimizer.apply_gradients(zip(gradients, variables)) return loss Luo tarkistuskohteet Käytä a luoda manuaalisesti valvontapiste, jossa kohteet, jotka haluat valvontapiste asetetaan ominaisuuksiksi objektissa. tf.train.Checkpoint A on Se voi myös olla hyödyllinen useiden tarkastuspisteiden hallintaan. tf.train.CheckpointManager opt = tf.keras.optimizers.Adam(0.1) dataset = toy_dataset() iterator = iter(dataset) ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator) manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3) Lentokenttä ja mallin tarkastus Seuraavassa koulutuskierros luo mallin ja optimoijan esimerkin, ja sitten kokoaa ne yhteen Se kutsuu koulutusvaiheen ympyrään jokaisesta datapaketista ja kirjoittaa määräajoin tarkistuspisteitä levylle. tf.train.Checkpoint def train_and_checkpoint(net, manager): ckpt.restore(manager.latest_checkpoint) if manager.latest_checkpoint: print("Restored from {}".format(manager.latest_checkpoint)) else: print("Initializing from scratch.") for _ in range(50): example = next(iterator) loss = train_step(net, example, opt) ckpt.step.assign_add(1) if int(ckpt.step) % 10 == 0: save_path = manager.save() print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path)) print("loss {:1.2f}".format(loss.numpy())) train_and_checkpoint(net, manager) Palauta ja jatka koulutusta Ensimmäisen koulutusjakson jälkeen voit siirtää uuden mallin ja johtajan, mutta poimia koulutuksen täsmälleen siellä, missä lopetit: opt = tf.keras.optimizers.Adam(0.1) net = Net() dataset = toy_dataset() iterator = iter(dataset) ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator) manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3) train_and_checkpoint(net, manager) Sillä objekti poistaa vanhat valvontapisteet. Ylhäällä se on määritetty säilyttämään vain kolme viimeisintä valvontapistettä. tf.train.CheckpointManager print(manager.checkpoints) # List the three remaining checkpoints Nämä reitit ovat mm. , eivät ole tiedostoja levyllä. Sen sijaan ne ovat etuliitteitä tiedosto ja yksi tai useampi tiedosto, jotka sisältävät muuttuvia arvoja.Nämä esikuvat on ryhmitelty yhteen Tiedosto ( 3) Missä on Se pelastaa valtion. './tf_ckpts/ckpt-10' index checkpoint './tf_ckpts/checkpoint' CheckpointManager ls ./tf_ckpts Mekaaninen kuorma TensorFlow vastaa muuttujia tarkistettuihin arvoihin ylittämällä suunnatun kaavion, jossa on nimettyjä reunoja, alkaen ladattavasta objektista. Sisällä ja käytä avainsanojen argumenttien nimiä, kuten Sisällä . "l1" self.l1 = tf.keras.layers.Dense(5) tf.train.Checkpoint "step" tf.train.Checkpoint(step=...) Yllä olevan esimerkin riippuvuusgrafiikka näyttää tältä: Optimisaattori on punainen, säännölliset muuttujat ovat sinisiä ja optimointilaatikon muuttujat ovat oransseja. Se on mustassa. tf.train.Checkpoint Slot-muuttujat ovat osa optimoijan tilaa, mutta ne luodaan tietylle muuttujalle. Edellä mainitut reunat vastaavat momenttia, jonka Adam-optimoija seuraa jokaisesta muuttujasta. Slot-muuttujat tallennetaan tarkistuspisteeseen vain, jos sekä muuttuja että optimoija tallennettaisiin, joten reunat tallennetaan. 'm' Soittaminen on a objekti ajoittaa pyydetyt palautukset, palauttaen muuttuvat arvot heti, kun vastaava polku on Voit esimerkiksi ladata vain edellä määritellyn mallin puolueen rakentamalla uudelleen yhden polun siihen verkon ja kerroksen kautta. restore tf.train.Checkpoint Checkpoint to_restore = tf.Variable(tf.zeros([5])) print(to_restore.numpy()) # All zeros fake_layer = tf.train.Checkpoint(bias=to_restore) fake_net = tf.train.Checkpoint(l1=fake_layer) new_root = tf.train.Checkpoint(net=fake_net) status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/')) print(to_restore.numpy()) # This gets the restored value. Näiden uusien objektien riippuvuusgrafiikka on paljon pienempi algrafiikka suuremmasta tarkistuspisteestä, jonka kirjoitit edellä. Käytä tarkastuspisteiden numeroita. tf.train.Checkpoint palauttaa tilanobjektin, jossa on valinnaisia väitteitä. Se on palautettu, joten Käy läpi restore Checkpoint status.assert_existing_objects_matched status.assert_existing_objects_matched() Valvontapisteessä on monia objekteja, jotka eivät ole vastanneet, mukaan lukien kerroksen ydin ja optimoijan muuttujat. Se kulkee vain, jos tarkistuspiste ja ohjelma täsmälleen vastaavat, ja heittäisi poikkeuksen täällä. status.assert_consumed Viivästyneet kunnostukset TensorFlow-objektit voivat lykätä muuttujien luomista ensimmäiseen kutsuunsa, kun syöttömuodot ovat käytettävissä. kerroksen ydin riippuu sekä kerroksen syöttö- että ulostulomuodoista, joten rakentajan argumenttina vaadittu ulostulomuoto ei ole riittävästi tietoa muuttujan luomiseen. myös lukee muuttujan arvon, palautus on tapahduttava muuttujan luomisen ja sen ensimmäisen käytön välillä. Layer Dense Layer Tämän idiootin tukeminen, poistaa palautukset, joilla ei vielä ole vastaavaa muuttujaa. tf.train.Checkpoint deferred_restore = tf.Variable(tf.zeros([1, 5])) print(deferred_restore.numpy()) # Not restored; still zeros fake_layer.kernel = deferred_restore print(deferred_restore.numpy()) # Restored Manuaalinen tarkastus tarkastuspisteissä Takaisin a joka antaa alemman tason pääsyn tarkistuspisteen sisältöön. Se sisältää kartoituksia kunkin muuttujan avaimesta tarkistuspisteen kunkin muuttujan muotoon ja dtyyppiin. Muuttujan avain on sen kohteen polku, kuten yllä olevissa kaaviossa. tf.train.load_checkpoint CheckpointReader Huomautus: Tarkistuspisteeseen ei ole korkeamman tason rakennetta. Se tietää vain muuttujien polut ja arvot, eikä sillä ole konseptia malleista, kerroksista tai siitä, miten ne on yhdistetty. Tarkastuspisteelle ei ole korkeampaa tasorakennetta.Se tietää vain muuttujien polut ja arvot, eikä sillä ole käsitystä ja Tai miten ne liittyvät toisiinsa. Note: models layers reader = tf.train.load_checkpoint('./tf_ckpts/') shape_from_key = reader.get_variable_to_shape_map() dtype_from_key = reader.get_variable_to_dtype_map() sorted(shape_from_key.keys()) Joten jos olet kiinnostunut arvosta Voit saada arvon seuraavalla koodilla: net.l1.kernel key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE' print("Shape:", shape_from_key[key]) print("Dtype:", dtype_from_key[key].name) Se tarjoaa myös a menetelmä, jonka avulla voit tarkistaa muuttujan arvon: get_tensor reader.get_tensor(key) Objektin seuranta Tarkistuspisteet tallentavat ja palauttavat objekteja "seurantamalla" mitä tahansa muuttujaa tai jäljitettävää objektia, joka on asetettu johonkin sen ominaisuuksista. tf.Variable Samankaltaisia ominaisuuksia kuin , luetteloiden ja sanakirjojen määrittäminen ominaisuuksiin seuraa niiden sisältöä. self.l1 = tf.keras.layers.Dense(5) save = tf.train.Checkpoint() save.listed = [tf.Variable(1.)] save.listed.append(tf.Variable(2.)) save.mapped = {'one': save.listed[0]} save.mapped['two'] = save.listed[1] save_path = save.save('./tf_list_example') restore = tf.train.Checkpoint() v2 = tf.Variable(0.) assert 0. == v2.numpy() # Not restored yet restore.mapped = {'two': v2} restore.restore(save_path) assert 2. == v2.numpy() Saatat huomata luetteloiden ja sanakirjojen pakkausobjekteja. Nämä pakkaukset ovat tarkistettavissa olevia versioita taustalla olevista tietorakenteista. Aivan kuten ominaisuuspohjainen lataus, nämä pakkaukset palauttavat muuttujan arvon heti, kun se lisätään säiliöön. restore.listed = [] print(restore.listed) # ListWrapper([]) v1 = tf.Variable(0.) restore.listed.append(v1) # Restores v1, from restore() in the previous cell assert 1. == v1.numpy() Jäljitettäviä kohteita ovat ja ja sen alaluokat (esim. ja Lisätietoja python-konteineristä: tf.train.Checkpoint tf.Module keras.layers.Layer keras.Model dikt (ja collections.OrderedDict) luettelo tuple (ja collections.namedtuple, typing.NamedTuple) Muita kontteja ovat mukaan lukien : not supported Vapaaehtoinen.defaultdict Setä Kaikki muut Python-objektit ovat mukaan lukien : ignored sisällä Stringissä purjehdus Yhteenveto TensorFlow-objektit tarjoavat helpon automaattisen mekanismin käytettyjen muuttujien arvojen tallentamiseen ja palauttamiseen. Alun perin julkaistu TensorFlow-verkkosivustolla, tämä artikkeli ilmestyy täällä uuden otsikon alla ja on lisensoitu CC BY 4.0. Alun perin julkaistu TensorFlow-verkkosivustolla, tämä artikkeli ilmestyy täällä uuden otsikon alla ja on lisensoitu CC BY 4.0. TensorFlowMuokkaa TensorFlowMuokkaa