程序師世界是廣大編程愛好者互助、分享、學習的平台,程序師世界有你更精彩!
首頁
編程語言
C語言|JAVA編程
Python編程
網頁編程
ASP編程|PHP編程
JSP編程
數據庫知識
MYSQL數據庫|SqlServer數據庫
Oracle數據庫|DB2數據庫
 程式師世界 >> 編程語言 >> C語言 >> C++ >> C++入門知識 >> HDU 3450 Counting Sequences(DP + 樹狀數組)

HDU 3450 Counting Sequences(DP + 樹狀數組)

編輯:C++入門知識

HDU 3450 Counting Sequences(DP + 樹狀數組)


題目鏈接:點擊打開鏈接

題目大意: 統計滿足相鄰兩個數之差不超過d的子序列個數。

我們不難想到一個O(n^2)的DP算法 : 對於每一個i, d[i]表示 以i結尾的子序列個數。 那麼它將轉移到所有滿足(j >= 1 && j < i && abs(a[j]-a[i])<=d)的d[j] 。

但是由於n太大了, 這樣顯然會超時, 那麼我們來想想如何優化這個算法: 可以發現, 對於每個d[i], 其累加的部分是一個(a[i] - d, a[i] + d)的范圍內的且在i之前出現過的所有d[j]。 這恰恰符合樹狀數組的特點: 求連續和、單點更新 。

所以我們不難想到每次更新完d[i] 之後, 在a[i]這個位置上增加d[i]+1。 但是該題沒有給a[i]的數據范圍, 得到WA之後證明, 數據應該很大, 數組開不下。 那麼我們只需要離散化一下, 將數據一一映射到1~n的范圍內就好了。然後用二分查找找到映射後的代碼,用樹狀數組求解。

離散化的一個比較簡單易行的方法就是用另一個數組b復制a,然後對b進行排序去重,那麼此時b的下標就是對數組a映射後的值,用二分查找可以很容易的找到。

復雜度O(n*logn)。

細節參見代碼:

 

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#define Max(a,b) ((a)>(b)?(a):(b))
#define Min(a,b) ((a)<(b)?(a):(b))
using namespace std;
typedef long long ll;
const double PI = acos(-1.0);
const double eps = 1e-6;
const int mod = 9901;
const int INF = 1000000000;
const int maxn = 100000 + 10;
ll T,n,m,e,d[maxn],bit[maxn],kase = 0,a[maxn],b[maxn];
ll sum(ll x) {
    ll ans = 0;
    while(x > 0) {
        ans = (ans + bit[x])%mod;
        x -= x & -x;
    }
    return ans;
}
void add(ll x, ll dd) {
    dd %= mod;
    while(x <= n) {
        bit[x] = (bit[x] + dd)%mod;
        x += x & -x;
    }
}
int main() {
    while(~scanf("%I64d%I64d",&n,&e)) {
        memset(bit, 0, (n+1)*sizeof(ll));
        ll maxv = 0;
        for(int i=1;i<=n;i++) {
            scanf("%I64d",&a[i]);
            b[i] = a[i];
            maxv = max(maxv, a[i]);
        }
        sort(b+1, b+n+1);
        int len = unique(b+1,b+n+1) - b - 1;
        for(int i=1;i<=n;i++) {
            ll l = lower_bound(b+1,b+len+1,a[i]-e) - b - 1;
            ll r = lower_bound(b+1,b+len+1,a[i]+e) - b;
            ll v = lower_bound(b+1,b+len+1,a[i]) - b;
            if(r > len || b[r] > a[i]+e) --r;
            d[i] = (sum(r) - sum(l) + mod) % mod;
            add(v, d[i]+1);
        }
        ll ans = 0;
        for(int i=1;i<=n;i++) {
            ans = (ans + d[i])%mod;
        }
        printf("%I64d\n",ans);
    }
    return 0;
}

  1. 上一頁:
  2. 下一頁:
Copyright © 程式師世界 All Rights Reserved