Pytorch+Catalyst+ResNet事前学習済モデルをTPUで動かす

筆者のPytorchの学習状況

KaggleだとPytorchで書かれたコードが多くいい加減やらねば・・・と思い最近UdemyでPytorchの講座を受けていました。

Kerasから入った身としてはKerasよりコード量が多いことがすこしネックでした。

 

モチベーション

Pytorch用のフレームワークを使いながらTPUを使えたらいいのにと思っていたところ、出たので学習済ResNet18をロードしてCIFAR10の学習をしてみました。 

 

実装

以下で公開しています。

colab.research.google.com

github.com

 

今後に期待すること

catalystのバージョンがstabelでないため or Pytochの挙動をいまいちわかっていないためか、colabのランタイムをGPUにした場合、

device = utils.get_device() だと'cpu'しか認識されませんでした。

Pytoch+TPUの実装はKaggleで公開されていたり、こちらの記事 で紹介されていましたが、torch_xlaのコード改変が必要で少し面倒です。

かったるい事は嫌いなタチなんで、今後はこちらの開発が加速していくことに期待しています。

Pytorch+TPU+(KaggleNotebook or GCP)の構成が普段使いできる展望が見えてきたのが非常にうれしいです。

 

 

 

 

 

ABC082 をPythonで解く

atcoder.jp

A~Cを解いた。

A - Round Up the Mean

import numpy as np
A,B = map(int,input().split())
ans = int(np.ceil((A+B)/2))
print(ans)

 

B - Two Anagrams

s = input()
t = input()

sl = list(s)
sl.sort()
tl = list(t)
tl.sort()
tl=tl[::-1]
S="".join(sl)
T="".join(tl)
if S<T:
    ans ='Yes'
else:
    ans = 'No'
print(ans)

C - Good Sequence

N = int(input())
A = [int(i) for i in input().split()]
from collections import Counter
ans = 0
for k,v in Counter(A).items():
    if v!=k and k>v:
        ans+=v
    elif v!=k and k<v:
        ans+=v-k
print(ans)

ABC166 をPythonで解く

atcoder.jp

A~Dを解いた。

A - A?C

S = input()
if S=="ABC":
    print("ARC")
else:
    print("ABC")

 

B - Trick or Treat

N,K = map(int,input().split())

O = [0]*N

for k in range(K):
    d = int(input())
    A = [int(i) for i in input().split()]
    for a in A:
        O[a-1]+=1
print(O.count(0))

C - Peaks

# ABC 166 C

N,M =map(int,input().split())
H= [int(j) for j in input().split()]
AB = []
for _ in range(M):
    a,b = map(int,input().split())
    AB.append([a,b])
    AB.append([b,a])
AB.sort(key=lambda s:s[0])
cnt =0
last_a = AB[0][0]
a_cnt = [0] *N
h_cnt = [0] *N
for a,b in AB:
    a-=1
    b-=1
    a_cnt[a] =a_cnt[a]+1
    if H[a] >H[b]:
        h_cnt[a]=h_cnt[a]+1
    
for i,a in enumerate(a_cnt):
    if h_cnt[i]==a:
        cnt+=1
print(cnt)  

D - I hate Factorization

X =int(input())
def jug(n):
    if n>0:
        return 1
    else:
        return -1

ALL = [i**5 for i in range(-200,200+1)]
for a in ALL:
    for b in ALL:
        if a-b ==X:
            ans =[jug(a)*int(abs(a)**(1.0/5)),jug(b)*int(abs(b)**(0.2))]
            print(ans[0],ans[1])
            exit()

ABC151 をPythonで解く

atcoder.jp

A~Cを解いた。

A

C =input()
ans = ord(C)
ans = chr(ans+1)
print(ans)

 

B

N,K,M = map(int,input().split())
A = [int(i) for i in input().split()]
ans = N*M - sum(A)
if ans >K:
    print(-1)
    exit()
print(max(ans,0))

C

N,M = map(int,input().split())
PS =[""]*N
PENALTY = 0
AC = 0
for m in range(M):
    p,s = map(str,input().split())
    p=int(p)-1
    if PS[p]!="AC" and s=="WA":
        PENALTY+=1
    elif PS[p]!="AC" and s=="AC":
        AC+=1
        PS[p]="AC"
print(AC,PENALTY)

ABC114 をPythonで解く

atcoder.jp

A~Cを解いた。

A

X =int(input())
if X in [7,5,3]:
    ans ="YES"
else:
    ans ="NO"
print(ans)

 

B

m=min(m,xxx)でコード量削減

S =input()
m = 753 
for i in range(len(S)-2):
    tmp=(S[i:i+3])
    m=min(abs(753-int(tmp)),m)
print(m)

C

さすがに今日はjoinで結合すぐ思い出せた。総当たりで組み合わせを作って7,5,3が一回以上出てくる場合のみに絞る。あとは大小を比較する。

from itertools import product
N = int(input())
STR_N = str(N)
all_pat =[]
for i in range(3,len(STR_N)+1):
    pattern = list(product(["7","5","3"],repeat=i))
    for p in pattern:
        if len(set(p))==3:
            all_pat.append(int(''.join(p)))
all_pat = list(set(all_pat))
all_pat.sort()
cnt = 0
for a in all_pat:
    if a<=N:
        cnt+=1
print(cnt)

ABC066 をPythonで解く

atcoder.jp

A~Cを解いた。

A

A = [int(a) for a in input().split()]
ans = sum(A)-max(A)
print(ans)

 

B

いままではif文で処理していた最大値部分をm = max(m,xxx))と書いている人がいたので真似してみた。当たり前と言われればそれまでだがスマートな書き方だと思う。

S =input()
m = 0
for i,s in enumerate(S):
    tmp = S[:-1-i]
    tmp1 = tmp[:len(tmp)//2]
    tmp2 = tmp[len(tmp)//2:]
    if len(tmp)%2==0 and(tmp1==tmp2):
        m = max(m,len(tmp))
print(m)

C

listに対してinsertとappendを行うとTLE。
初めてdequeを使った。listの中身を1行で表示するのはjoinを使うと一発だが毎回検索しているのでそろそろ覚えたい。

N = int(input())
A = input().split()
from collections import deque

tmp = deque([])
for i,a in enumerate(A):
    if i %2!=0:
        tmp.appendleft(a)
    else:
        tmp.append(a)
if N%2!=0:
    tmp.reverse()
print(' '.join(tmp))

ABC067 をPythonで解く

atcoder.jp

A~Cを解いた。

A

A,B =map(int,input().split())

if A%3==0 or B%3==0 or (A+B)%3==0:
    ans = "Possible"
else:
    ans ="Impossible"
print(ans)

 

B

N,K = map(int,input().split())
L =[int(i) for i in input().split()]
L.sort(reverse=True)
ans = sum(L[:K])
print(ans)

C

累積和問題。
内包表記内にsum(A)を書いていると毎回計算させることとなりTLEしてしまう。 total変数に格納して演算を一度にする。これだけで処理時間が8分の1になる。

from itertools import accumulate
N = int(input())
A = [int(i) for i in input().split()]
tmp = list(accumulate(A))[:len(A)-1]
total  = sum(A)
ans = min([abs(total-2*a) for a in tmp])
print(ans)