程序師世界是廣大編程愛好者互助、分享、學習的平台,程序師世界有你更精彩!
首頁
編程語言
C語言|JAVA編程
Python編程
網頁編程
ASP編程|PHP編程
JSP編程
數據庫知識
MYSQL數據庫|SqlServer數據庫
Oracle數據庫|DB2數據庫
您现在的位置: 程式師世界 >> 編程語言 >  >> 更多編程語言 >> Python

dsx-rl中遇到的python函數的筆記

編輯:Python

1.zip()函數
zip() 函數用於將可迭代的對象作為參數,將對象中對應的元素打包成一個個元組,然後返回由這些元組組成的對象,這樣做的好處是節約了不少的內存。

我們可以使用 list() 轉換來輸出列錶。

如果各個迭代器的元素個數不一致,則返回列錶長度與最短的對象相同,利用 * 號操作符配合zip函數,可以將元組解壓為列錶。

>>> a = [1,2,3]
>>> b = [4,5,6]
>>> c = [4,5,6,7,8]
>>> zipped = zip(a,b) # 返回一個對象
>>> zipped
<zip object at 0x103abc288>
>>> list(zipped) # list() 轉換為列錶
[(1, 4), (2, 5), (3, 6)]
>>> list(zip(a,c)) # 元素個數與最短的列錶一致
[(1, 4), (2, 5), (3, 6)]
>>> a1, a2 = zip(*zip(a,b)) # 與 zip 相反,zip(*) 可理解為解壓,返回二維矩陣式
>>> list(a1)
[1, 2, 3]
>>> list(a2)
[4, 5, 6]
>>>

https://www.runoob.com/python3/python3-func-zip.html

2.np.random.random()函數
當無參數傳入時返回一個0-1的隨機數
當傳入參數則返回shape為參數的0-1的隨機數的數組

3.numpy.random.randint()函數

numpy.random.randint(low, high=None, size=None, dtype='l')

函數的作用是,返回一個隨機整型數,範圍從低(包括)到高(不包括),即[low, high)。
如果沒有寫參數high的值,則返回[0,low)的值。

>>> np.random.randint(2, size=10)
array([1, 0, 0, 0, 1, 1, 0, 0, 1, 0])
>>> np.random.randint(1, size=10)
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
>>> np.random.randint(5, size=(2, 4))
array([[4, 0, 2, 1],
[3, 2, 2, 0]])
>>>np.random.randint(2, high=10, size=(2,3))
array([[6, 8, 7],
[2, 5, 2]])

https://blog.csdn.net/u011851421/article/details/83544853
4.gather函數

gather函數的功能可以解釋為根據 index 參數(即是索引)返回數組裏面對應比特置的值
這裏的b.gather()寫法和torch.gather(b)的寫法都可以,重點是兩個參數,dim和index

低維的理解方式
dim=0錶示按行來索引,也就是說index的值錶示的是第幾行
dim=1錶示按列來索引,也就是指index的值錶示的是第幾列
5.torch.distributions.Categorical

probs = torch.FloatTensor([0.9,0.2])
ac = torch.distributions.Categorical(probs)
print(ac)
for _ in range(5):
print(ac.sample())

其作用是創建以參數probs為標准的類別分布,樣本是來自“0,…,K-1”的整數,K是probs參數的長度。也就是說,按照probs的概率,在相應的比特置進行采樣,采樣返回的是該比特置的整數索引。

再看一下在rl中依據策略網絡選擇動作:

 def take_action(self, state): # 根據動作概率分布隨機采樣
state = torch.tensor([state], dtype=torch.float).to(self.device) # 1*4
probs = self.policy_net(state) # 1*2
action_dist = torch.distributions.Categorical(probs)
action = action_dist.sample()
return action.item()


  1. 上一篇文章:
  2. 下一篇文章:
Copyright © 程式師世界 All Rights Reserved